Comparison of EIS and the CEM for SSMs

Simplified version of regional model in Chapter 4.1, keeping only \(\log I_t\) and \(\log \rho_t\) in the states.

  • States \(X_t = \left(\log I_{t}, \log \rho_{t + 1}\right)\)
  • Observations \(Y_t | X_t \sim \operatorname{Pois} \left( \exp \log I_{t}\right)\)

Varying \(n = 10, 100, 1000\). Initialize \(\log \rho_0 = 0\) with small variance and \(\log I_0 = \log 1000\) with small variance as well.

Let \(\sigma^2_\rho = \frac{1}{n}0.05\), s.t. \(\operatorname{Var} (\log \rho_{n +1}) = 0.05\) and approx. \(\mathbf P(\log \rho_{n + 1} \in [-0.1, 0.1]) \geq 0.95\), so approx. \(\rho_{n +1} \in [0.9, 1.1]\), ensuring stability of infection counts (don’t go to \(0\) or \(\infty\)).

# Imports, grouped stdlib / third-party / local; the duplicate
# `posterior_mode` import has been removed.
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jrn
import pandas as pd
from jax import vmap
from pyprojroot import here
from tensorflow_probability.substrates.jax.distributions import Poisson
from tqdm.notebook import tqdm

from isssm.ce_method import cross_entropy_method as CEM
from isssm.ce_method import log_weight_cem, simulate_cem
from isssm.importance_sampling import ess_pct, pgssm_importance_sampling
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.laplace_approximation import posterior_mode
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as MEIS,
)
from isssm.pgssm import simulate_pgssm
from isssm.typing import PGSSM

# double precision: needed for numerically stable covariance / slogdet work below
jax.config.update("jax_enable_x64", True)
# parameters

# Monte Carlo configuration shared by all experiments below.
N_samples = 10_000  # samples used to estimate the optimal proposal
N_ef = 1_000  # samples used for the efficiency-factor estimates
N_iter = 100  # iterations for CEM / EIS / LA
M = 100  # repetitions used to obtain the covariance matrices
K = 10  # repetitions of the ARE experiment per n
K_ef = 100  # repetitions of the efficiency-factor experiment per n

# LaTeX snippet describing the parameters; written to the thesis sources.
# (fixed grammar: "a single asymptotic variance ratio", singular)
parameters_tex = f"""
We set the number of iterations of the \\gls{{cem}} and \\gls{{eis}} to ${N_iter}$, which, in our experience, suffices to determine whether the numerical scheme converges or diverges. We use $M={M}$ samples to obtain the covariance matrices. For both methods we use $N = {N_samples}$ samples for estimation.  

The above procedure generates a single asymptotic variance ratio for a fixed number of time points $n$. As the performance of importance sampling is likely influenced by the sample $y$, we repeat the simulation $K={K}$ times to obtain $K$ different outcomes. 
"""

with open(
    here("chapters/03_state_space_models/03_08_comparison_ssm_parameters_var.tex"), "w"
) as f:
    f.write(parameters_tex)

text_parameters_repeat = f"""Again, we repeat this procedure $K={K_ef}$ times for varying levels of $n$ and use ${N_iter}$ iterations for all three methods, as well as ${N_samples}$ samples to estimate the optimal proposal."""

with open(
    here("chapters/03_state_space_models/03_08_comparison_ssm_parameters_ef.tex"), "w"
) as f:
    f.write(text_parameters_repeat)
The Kernel crashed while executing code in the current cell or a previous cell. 

Please review the code in the cell(s) to identify a possible cause of the failure. 

Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. 

View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.
def _model(n, I0):
    """Build the simplified regional PGSSM with states (log I_t, log rho_{t+1}).

    Args:
        n: number of state transitions, i.e. there are n + 1 time points.
        I0: initial number of infected; enters the state mean as log(I0).

    Returns:
        A ``PGSSM`` with Poisson observations Y_t | X_t ~ Pois(exp(log I_t)).
    """
    np1 = n + 1
    # innovation variance of log rho, chosen so Var(log rho_{n+1}) = 0.05
    s2_rho = 0.05 / n if n > 1 else 1

    dim_x, dim_y, dim_eps = 2, 1, 1

    # state means: only the infection component starts away from zero
    u = jnp.zeros((np1, dim_x)).at[0, 0].set(jnp.log(I0))

    # transition: log I_{t+1} = log I_t + log rho_{t+1}; log rho is a random walk
    A = jnp.broadcast_to(jnp.array([[1.0, 1.0], [0.0, 1.0]]), (n, dim_x, dim_x))
    # innovations enter only the second (log rho) component
    D = jnp.broadcast_to(jnp.eye(dim_x)[:, 1:2], (n, dim_x, dim_eps))

    Sigma0 = jnp.array([[1.0, 0.0], [0.0, 0.1]])
    Sigma = jnp.broadcast_to(s2_rho * jnp.eye(1), (n, dim_eps, dim_eps))

    # observations read off the first state component, log I_t
    B = jnp.broadcast_to(jnp.eye(dim_x)[:1], (np1, dim_y, dim_x))
    v = jnp.zeros((np1, dim_y))

    def _poisson_obs(signal, xi):
        # xi is unused: the Poisson rate is fully determined by the signal
        return Poisson(log_rate=signal)

    xi = jnp.empty((np1, dim_y, 1))
    return PGSSM(u, A, D, Sigma0, Sigma, v, B, _poisson_obs, xi)
def determine_efficiency_factor(n, key):
    """Estimate ESS percentages for the LA, MEIS and CEM proposals.

    Simulates a single observation path of length n + 1 from the model,
    fits the three proposals and evaluates the percentage effective sample
    size of the resulting importance-sampling weights.

    Args:
        n: number of state transitions of the model.
        key: PRNG key consumed for simulation and proposal fitting.

    Returns:
        ``pd.Series`` with the simulation parameters and the efficiency
        factors ``EF_LA``, ``EF_MEIS`` and ``EF_CEM``.
    """
    pgssm = _model(n, I0=1000)
    key, subkey = jrn.split(key)

    # a single simulated observation sequence Y
    _, (Y,) = simulate_pgssm(pgssm, 1, subkey)

    key, sk_meis, sk_cem = jrn.split(key, 3)
    prop_la, _ = LA(Y, pgssm, N_iter)
    # MEIS is initialized at the Laplace approximation
    prop_meis, _ = MEIS(Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, sk_meis)
    # CEM already returns its importance-sampling log-weights
    prop_cem, lw_cem = CEM(pgssm, Y, N_samples, sk_cem, N_iter)

    # fresh log-weights with N_ef samples for LA and MEIS
    key, sk_la, sk_meis, sk_cem = jrn.split(key, 4)
    _, lw_la = pgssm_importance_sampling(
        Y, pgssm, prop_la.z, prop_la.Omega, N_ef, sk_la
    )
    _, lw_meis = pgssm_importance_sampling(
        Y, pgssm, prop_meis.z, prop_meis.Omega, N_ef, sk_meis
    )

    result = pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "EF_LA": ess_pct(lw_la),
            "EF_MEIS": ess_pct(lw_meis),
            "EF_CEM": ess_pct(lw_cem),
        }
    )

    return result
key = jrn.PRNGKey(140235293)
# K_ef repetitions for each number of time points n
ns_ef = jnp.repeat(jnp.array([1, 10, 20, 50, 100]), K_ef)
key, *keys_ef = jrn.split(key, len(ns_ef) + 1)

# one pd.Series per (n, key) pair
results_list = [
    determine_efficiency_factor(n, k)
    for n, k in tqdm(zip(ns_ef, keys_ef), total=len(ns_ef))
]
results_ef = pd.DataFrame(results_list)

results_ef.to_csv(here("data/figures/ef_meis_cem_ssms.csv"), index=False)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[54], line 4
      1 results_list = []
      3 for n, k in tqdm(zip(ns_ef, keys_ef), total=len(ns_ef)):
----> 4     results_list.append(determine_efficiency_factor(n, k))
      5 results_ef = pd.DataFrame(results_list)
      7 results_ef.to_csv(here("data/figures/ef_meis_cem_ssms.csv"), index=False)

Cell In[52], line 8, in determine_efficiency_factor(n, key)
      5 _, (Y,) = simulate_pgssm(pgssm, 1, subkey)
      7 key, sk_meis, sk_cem = jrn.split(key, 3)
----> 8 prop_la, _ = LA(Y, pgssm, N_iter)
      9 prop_meis, _ = MEIS(Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, sk_meis)
     10 prop_cem, lw_cem = CEM(pgssm, Y, N_samples, sk_cem, N_iter)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/laplace_approximation.py:62, in laplace_approximation(y, model, n_iter, log_lik, d_log_lik, dd_log_lik, eps, link)
     59 u, A, D, Sigma0, Sigma, v, B, dist, xi = model
     60 np1, p, m = B.shape
---> 62 s_init = vvmap(partial(_initial_guess, dist=dist, link=link))(xi, y)
     64 def default_log_lik(s_ti, xi_ti, y_ti):
     65     return dist(s_ti, xi_ti).log_prob(y_ti).sum()

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/api.py:1127, in vmap.<locals>.vmap_f(*args, **kwargs)
   1124 try:
   1125   axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
   1126                                 explicit_mesh_axis)
-> 1127   out_flat = batching.batch(
   1128       flat_fun, axis_data, in_axes_flat,
   1129       lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
   1130   ).call_wrapped(*args_flat)
   1131 except batching.SpecMatchError as e:
   1132   out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py:211, in WrappedFun.call_wrapped(self, *args, **kwargs)
    209 def call_wrapped(self, *args, **kwargs):
    210   """Calls the transformed function"""
--> 211   return self.f_transformed(*args, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:609, in _batch_outer(f, axis_data, in_dims, *in_vals)
    607 tag = TraceTag()
    608 with source_info_util.transform_name_stack('vmap'):
--> 609   outs, trace = f(tag, in_dims, *in_vals)
    610 with core.ensure_no_leaks(trace): del trace
    611 return outs

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:625, in _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals)
    621   in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
    622 with (core.set_current_trace(trace),
    623       core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
    624       core.add_spmd_axis_names(axis_data.spmd_name)):
--> 625   outs = f(*in_tracers)
    626   out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
    627   out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis),
    628                  range(len(outs)), outs, out_dim_dests)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:340, in flatten_fun_for_vmap(f, store, in_tree, *args_flat)
    336 @lu.transformation_with_aux2
    337 def flatten_fun_for_vmap(f: Callable,
    338                          store: lu.Store, in_tree: PyTreeDef, *args_flat):
    339   py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
--> 340   ans = f(*py_args, **py_kwargs)
    341   ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable)
    342   store.store(out_tree)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py:402, in _get_result_paths_thunk(_fun, _store, *args, **kwargs)
    400 @transformation_with_aux2
    401 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
--> 402   ans = _fun(*args, **kwargs)
    403   result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans))
    404   if _store:
    405     # In some instances a lu.WrappedFun is called multiple times, e.g.,
    406     # the bwd function in a custom_vjp

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/api.py:1127, in vmap.<locals>.vmap_f(*args, **kwargs)
   1124 try:
   1125   axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
   1126                                 explicit_mesh_axis)
-> 1127   out_flat = batching.batch(
   1128       flat_fun, axis_data, in_axes_flat,
   1129       lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
   1130   ).call_wrapped(*args_flat)
   1131 except batching.SpecMatchError as e:
   1132   out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py:211, in WrappedFun.call_wrapped(self, *args, **kwargs)
    209 def call_wrapped(self, *args, **kwargs):
    210   """Calls the transformed function"""
--> 211   return self.f_transformed(*args, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:609, in _batch_outer(f, axis_data, in_dims, *in_vals)
    607 tag = TraceTag()
    608 with source_info_util.transform_name_stack('vmap'):
--> 609   outs, trace = f(tag, in_dims, *in_vals)
    610 with core.ensure_no_leaks(trace): del trace
    611 return outs

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:625, in _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals)
    621   in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
    622 with (core.set_current_trace(trace),
    623       core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
    624       core.add_spmd_axis_names(axis_data.spmd_name)):
--> 625   outs = f(*in_tracers)
    626   out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
    627   out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis),
    628                  range(len(outs)), outs, out_dim_dests)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:340, in flatten_fun_for_vmap(f, store, in_tree, *args_flat)
    336 @lu.transformation_with_aux2
    337 def flatten_fun_for_vmap(f: Callable,
    338                          store: lu.Store, in_tree: PyTreeDef, *args_flat):
    339   py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
--> 340   ans = f(*py_args, **py_kwargs)
    341   ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable)
    342   store.store(out_tree)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py:402, in _get_result_paths_thunk(_fun, _store, *args, **kwargs)
    400 @transformation_with_aux2
    401 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
--> 402   ans = _fun(*args, **kwargs)
    403   result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans))
    404   if _store:
    405     # In some instances a lu.WrappedFun is called multiple times, e.g.,
    406     # the bwd function in a custom_vjp

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/laplace_approximation.py:41, in _initial_guess(xi_ti, y_ti, dist, link)
     40 def _initial_guess(xi_ti, y_ti, dist, link=default_link):
---> 41     result = minimize(
     42         lambda s_ti: -dist(s_ti, xi_ti).log_prob(y_ti).sum(),
     43         jnp.atleast_1d(default_link(y_ti)),
     44         method="BFGS",
     45     )
     46     return jnp.squeeze(result.x)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/scipy/optimize/minimize.py:109, in minimize(fun, x0, args, method, tol, options)
    106 fun_with_args = lambda x: fun(x, *args)
    108 if method.lower() == 'bfgs':
--> 109   results = minimize_bfgs(fun_with_args, x0, **options)
    110   success = results.converged & jnp.logical_not(results.failed)
    111   return OptimizeResults(x=results.x_k,
    112                          success=success,
    113                          status=results.status,
   (...)
    118                          njev=results.ngev,
    119                          nit=results.k)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/scipy/optimize/bfgs.py:168, in minimize_bfgs(fun, x0, maxiter, norm, gtol, line_search_maxiter)
    157   state = state._replace(
    158       converged=converged,
    159       k=state.k + 1,
   (...)
    164       old_old_fval=state.f_k,
    165   )
    166   return state
--> 168 state = lax.while_loop(cond_fun, body_fun, state)
    169 status = jnp.where(
    170     state.converged,
    171     0,  # converged
   (...)
    180     )
    181 )
    182 state = state._replace(status=status)

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1636, in while_loop(cond_fun, body_fun, init_val)
   1633   init_vals, new_body_consts = partition_list(move_to_const, init_vals)
   1634   body_consts = [*new_body_consts, *body_consts]
-> 1636 outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
   1637                     cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
   1638                     body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
   1640 if any(move_to_const):
   1641   outs = pe.merge_lists(move_to_const, outs, new_body_consts)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params)
    529 def bind(self, *args, **params):
    530   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 531   return self._true_bind(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params)
    549 trace_ctx.set_trace(eval_trace)
    550 try:
--> 551   return self.bind_with_trace(prev_trace, args, params)
    552 finally:
    553   trace_ctx.set_trace(prev_trace)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params)
    555 def bind_with_trace(self, trace, args, params):
--> 556   return trace.process_primitive(self, args, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:496, in BatchTrace.process_primitive(self, p, tracers, params)
    494   else:
    495     with core.set_current_trace(self.parent_trace):
--> 496       val_out, dim_out = fancy_primitive_batchers[p](
    497           self.axis_data, vals_in, dims_in, **params)
    498 elif args_not_mapped:
    499   # no-op shortcut
    500   return p.bind_with_trace(self.parent_trace, vals_in, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1769, in _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr)
   1766     assert new_axis is not batching.not_mapped
   1767     new_init.append(batching.moveaxis(x, old_axis, new_axis))
-> 1769 outs = while_p.bind(*(cconsts + bconsts + new_init),
   1770                     cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
   1771                     body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
   1772 return outs, carry_dims

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params)
    529 def bind(self, *args, **params):
    530   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 531   return self._true_bind(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params)
    549 trace_ctx.set_trace(eval_trace)
    550 try:
--> 551   return self.bind_with_trace(prev_trace, args, params)
    552 finally:
    553   trace_ctx.set_trace(prev_trace)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params)
    555 def bind_with_trace(self, trace, args, params):
--> 556   return trace.process_primitive(self, args, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:496, in BatchTrace.process_primitive(self, p, tracers, params)
    494   else:
    495     with core.set_current_trace(self.parent_trace):
--> 496       val_out, dim_out = fancy_primitive_batchers[p](
    497           self.axis_data, vals_in, dims_in, **params)
    498 elif args_not_mapped:
    499   # no-op shortcut
    500   return p.bind_with_trace(self.parent_trace, vals_in, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1769, in _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr)
   1766     assert new_axis is not batching.not_mapped
   1767     new_init.append(batching.moveaxis(x, old_axis, new_axis))
-> 1769 outs = while_p.bind(*(cconsts + bconsts + new_init),
   1770                     cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
   1771                     body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
   1772 return outs, carry_dims

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params)
    529 def bind(self, *args, **params):
    530   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 531   return self._true_bind(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params)
    549 trace_ctx.set_trace(eval_trace)
    550 try:
--> 551   return self.bind_with_trace(prev_trace, args, params)
    552 finally:
    553   trace_ctx.set_trace(prev_trace)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params)
    555 def bind_with_trace(self, trace, args, params):
--> 556   return trace.process_primitive(self, args, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:1060, in EvalTrace.process_primitive(self, primitive, args, params)
   1058 args = map(full_lower, args)
   1059 check_eval_args(args)
-> 1060 return primitive.impl(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/dispatch.py:88, in apply_primitive(prim, *args, **params)
     86 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     87 try:
---> 88   outs = fun(*args)
     89 finally:
     90   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:334, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
    329 if config.no_tracing.value:
    330   raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
    331                      "`jit`, but 'no_tracing' is set")
    333 (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, box_data,
--> 334  executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
    336 maybe_fastpath_data = _get_fastpath_data(
    337     executable, out_tree, args_flat, out_flat, attrs_tracked, box_data,
    338     jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler)
    340 return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:197, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    195   out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
    196 else:
--> 197   out_flat = pjit_p.bind(*args_flat, **p.params)
    198   compiled = None
    199   profiler = None

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params)
    529 def bind(self, *args, **params):
    530   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 531   return self._true_bind(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params)
    549 trace_ctx.set_trace(eval_trace)
    550 try:
--> 551   return self.bind_with_trace(prev_trace, args, params)
    552 finally:
    553   trace_ctx.set_trace(prev_trace)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params)
    555 def bind_with_trace(self, trace, args, params):
--> 556   return trace.process_primitive(self, args, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:1060, in EvalTrace.process_primitive(self, primitive, args, params)
   1058 args = map(full_lower, args)
   1059 check_eval_args(args)
-> 1060 return primitive.impl(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1928, in _pjit_call_impl(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)
   1920 donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
   1921 cache_key = pxla.JitGlobalCppCacheKeys(
   1922     donate_argnums=donated_argnums, donate_argnames=None,
   1923     device=None, backend=None,
   (...)
   1926     in_layouts_treedef=None, in_layouts_leaves=in_layouts,
   1927     out_layouts_treedef=None, out_layouts_leaves=out_layouts)
-> 1928 return xc._xla.pjit(
   1929     name, f, call_impl_cache_miss, [], [], cache_key,
   1930     tree_util.dispatch_registry, pxla.cc_shard_arg,
   1931     _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1905, in _pjit_call_impl.<locals>.call_impl_cache_miss(*args_, **kwargs_)
   1904 def call_impl_cache_miss(*args_, **kwargs_):
-> 1905   out_flat, compiled, pgle_profiler = _pjit_call_impl_python(
   1906       *args, jaxpr=jaxpr, in_shardings=in_shardings,
   1907       out_shardings=out_shardings, in_layouts=in_layouts,
   1908       out_layouts=out_layouts, donated_invars=donated_invars,
   1909       ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
   1910       inline=inline, compiler_options_kvs=compiler_options_kvs)
   1911   fastpath_data = _get_fastpath_data(
   1912       compiled, tree_structure(out_flat), args, out_flat, [], [],
   1913       jaxpr.effects, jaxpr.consts, None, pgle_profiler)
   1914   return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1862, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)
   1850 compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
   1851 # Passing mutable PGLE profile here since it should be extracted by JAXPR to
   1852 # initialize the fdo_profile compile option.
   1853 compiled = _resolve_and_lower(
   1854     args, jaxpr=jaxpr, in_shardings=in_shardings,
   1855     out_shardings=out_shardings, in_layouts=in_layouts,
   1856     out_layouts=out_layouts, donated_invars=donated_invars,
   1857     ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
   1858     inline=inline, lowering_platforms=None,
   1859     lowering_parameters=mlir.LoweringParameters(),
   1860     pgle_profiler=pgle_profiler,
   1861     compiler_options_kvs=compiler_options_kvs,
-> 1862 ).compile()
   1864 # This check is expensive so only do it if enable_checks is on.
   1865 if compiled._auto_spmd_lowering and config.enable_checks.value:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2467, in MeshComputation.compile(self, compiler_options)
   2465 compiler_options_kvs = self._compiler_options_kvs + t_compiler_options
   2466 if self._executable is None or compiler_options_kvs:
-> 2467   executable = UnloadedMeshExecutable.from_hlo(
   2468       self._name, self._hlo, **self.compile_args,
   2469       compiler_options_kvs=compiler_options_kvs)
   2470   if not compiler_options_kvs:
   2471     self._executable = executable

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:3009, in UnloadedMeshExecutable.from_hlo(***failed resolving arguments***)
   3006       break
   3008 util.test_event("pxla_cached_compilation")
-> 3009 xla_executable = _cached_compilation(
   3010     hlo, name, mesh, spmd_lowering,
   3011     tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
   3012     allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
   3013     compiler_options_kvs, pgle_profiler)
   3015 if auto_spmd_lowering:
   3016   assert mesh is not None

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2800, in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_kvs, pgle_profiler)
   2792 compile_options = create_compile_options(
   2793     computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,
   2794     allow_prop_to_inputs, allow_prop_to_outputs, backend,
   2795     dev, pmap_nreps, compiler_options)
   2797 with dispatch.log_elapsed_time(
   2798     "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
   2799     fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2800   xla_executable = compiler.compile_or_get_cached(
   2801       backend, computation, dev, compile_options, host_callbacks,
   2802       da, pgle_profiler)
   2803 return xla_executable

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/compiler.py:447, in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks, executable_devices, pgle_profiler)
    445 else:
    446   log_persistent_cache_miss(module_name, cache_key)
--> 447   return _compile_and_write_cache(
    448       backend,
    449       computation,
    450       executable_devices,
    451       compile_options,
    452       host_callbacks,
    453       module_name,
    454       cache_key,
    455   )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/compiler.py:719, in _compile_and_write_cache(backend, computation, executable_devices, compile_options, host_callbacks, module_name, cache_key)
    709 def _compile_and_write_cache(
    710     backend: xc.Client,
    711     computation: ir.Module,
   (...)
    716     cache_key: str,
    717 ) -> xc.LoadedExecutable:
    718   start_time = time.monotonic()
--> 719   executable = backend_compile(
    720       backend, computation, executable_devices, compile_options, host_callbacks
    721   )
    722   compile_time = time.monotonic() - start_time
    723   _cache_write(
    724       cache_key, compile_time, module_name, backend, executable, host_callbacks
    725   )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/profiler.py:354, in annotate_function.<locals>.wrapper(*args, **kwargs)
    351 @wraps(func)
    352 def wrapper(*args, **kwargs):
    353   with TraceAnnotation(name, **decorator_kwargs):
--> 354     return func(*args, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/compiler.py:335, in backend_compile(backend, module, executable_devices, options, host_callbacks)
    326     return backend.compile(
    327         built_c,
    328         executable_devices=executable_devices,  # type: ignore
    329         compile_options=options,
    330         host_callbacks=host_callbacks,
    331     )
    332   # Some backends don't have `host_callbacks` option yet
    333   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    334   # to take in `host_callbacks`
--> 335   return backend.compile(
    336       built_c, executable_devices=executable_devices, compile_options=options)  # type: ignore
    337 except xc.XlaRuntimeError as e:
    338   for error_handler in _XLA_RUNTIME_ERROR_HANDLERS:

KeyboardInterrupt: 
def asymptotic_det_meis(Y, pgssm, prop_la, N_iter, N_samples, key, M: int):
    """Log-determinant of the scaled covariance of M MEIS posterior modes.

    Fits MEIS (initialized at the Laplace approximation) M times with
    independent keys, collects the posterior modes of the fitted proposals
    and returns log det of their empirical covariance times N_samples.
    """
    key, *fit_keys = jrn.split(key, 1 + M)

    modes = []
    for fk in fit_keys:
        proposal, _ = MEIS(Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, fk)
        modes.append(posterior_mode(proposal).reshape(-1))

    scaled_cov = N_samples * jnp.cov(jnp.array(modes), rowvar=False)
    _, logdet = jnp.linalg.slogdet(scaled_cov)

    return logdet


def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int):
    """Log-determinant of the scaled covariance of M CEM proposal means.

    Runs the CEM M times with independent keys, collects the first column
    of each fitted proposal mean and returns log det of their empirical
    covariance times N_samples.
    """
    key, *fit_keys = jrn.split(key, 1 + M)

    means = []
    for fk in fit_keys:
        proposal, _ = CEM(pgssm, Y, N_samples, fk, N_iter)
        means.append(proposal.mean[:, 0])

    scaled_cov = N_samples * jnp.cov(jnp.array(means), rowvar=False)
    _, logdet = jnp.linalg.slogdet(scaled_cov)

    return logdet


def asymptotic_variance(n: int, key: jrn.PRNGKey):
    """Compare the asymptotic variance of CEM and MEIS on one simulated SSM.

    Simulates observations from ``_model(n, I0=1000)``, fits a Laplace
    approximation as the MEIS warm start, then computes the log-determinant
    of the scaled covariance of the repeated-fit modes for both methods and
    their asymptotic relative efficiency (ARE).

    Parameters
    ----------
    n : int
        Model size parameter passed to ``_model``.
    key : PRNG key
        Source of all randomness in this experiment.

    Returns
    -------
    pd.Series
        With entries n, N_samples, N_iter, log_DET_CEM, log_DET_MEIS, ARE.
    """
    pgssm = _model(n, I0=1000)
    key, subkey = jrn.split(key)

    _, (Y,) = simulate_pgssm(pgssm, 1, subkey)

    prop_la, _ = LA(Y, pgssm, N_iter)

    # BUG FIX: previously `key` was split into 2*M subkeys that were never
    # used (only their lengths were), and BOTH helpers received the same
    # `key`. Each helper splits its key into 1+M subkeys internally, so CEM
    # and MEIS drew identical random subkeys. Give each method its own key
    # so the two Monte-Carlo estimates are independent.
    key_cem, key_meis = jrn.split(key)

    logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key_cem, M=M)
    logdet_meis = asymptotic_det_meis(
        Y, pgssm, prop_la, N_iter, N_samples, key_meis, M=M
    )

    result = pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "log_DET_CEM": logdet_cem,
            "log_DET_MEIS": logdet_meis,
            # ratio of determinants = asymptotic relative efficiency
            "ARE": jnp.exp(logdet_cem - logdet_meis),
        }
    )

    return result
# Run the ARE experiment across K replications per n, caching results to CSV.
key = jrn.PRNGKey(140235293)
# each n value is repeated K times for Monte-Carlo replication
# NOTE(review): the prose above mentions n = 10, 100, 1000 — here n ranges
# over 1..5; presumably _model rescales n internally. Confirm against _model.
ns_are = jnp.repeat(jnp.array([1, 2, 3, 4, 5]), K)
key, *keys_are = jrn.split(key, len(ns_are) + 1)
are_meis_cem_ssm_path = here("data/figures/are_meis_cem_ssms.csv")
if not are_meis_cem_ssm_path.exists():
    results_are = pd.DataFrame(
        [
            asymptotic_variance(n, k)
            for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are))
        ]
    )

    results_are.to_csv(are_meis_cem_ssm_path, index=False)
else:
    # BUG FIX: previously `results_are` was left undefined when the cached
    # CSV already existed, breaking any downstream use on a warm cache.
    results_are = pd.read_csv(are_meis_cem_ssm_path)
results_are
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2294, in _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack, arg_names, result_names)
   2293 try:
-> 2294   func_op = ctx.cached_primitive_lowerings[key]
   2295 except KeyError:

KeyError: (None, let norm = { lambda ; a:f64[1000,4]. let
    b:f64[1000,4] = mul a a
    c:f64[1000] = reduce_sum[axes=(1,)] b
    d:f64[1000] = sqrt c
  in (d,) } in
let tril = { lambda ; e:f64[2,2]. let
    f:i32[2,2] = iota[dimension=0 dtype=int32 shape=(2, 2) sharding=None] 
    g:i32[2,2] = add f 0:i32[]
    h:i32[2,2] = iota[dimension=1 dtype=int32 shape=(2, 2) sharding=None] 
    i:bool[2,2] = ge g h
    j:f64[2,2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2, 2)
      sharding=None
    ] 0.0:f64[]
    k:f64[2,2] = select_n i j e
  in (k,) } in
let diagonal = { lambda ; l:f64[2,2]. let
    m:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] 
    n:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] 
    o:i64[2,2] = add m 0:i64[]
    p:bool[2,2] = eq o n
    q:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] p
    r:i32[] = platform_index[platforms=(('mosaic',), None)] 
    s:f64[2] = cond[
      branches=(
        { lambda ; t:f64[2,2] u:f64[2,2]. let
            v:f64[2,2] = mul t u
            w:f64[2] = reduce_sum[axes=(0,)] v
          in (w,) }
        { lambda ; x:f64[2,2] y:f64[2,2]. let
            z:i64[2] = iota[dimension=0 dtype=int64 shape=(2,) sharding=None] 
            ba:i64[2] = iota[dimension=0 dtype=int64 shape=(2,) sharding=None] 
            bb:bool[2] = lt z 0:i64[]
            bc:i64[2] = add z 2:i64[]
            bd:i64[2] = select_n bb z bc
            be:bool[2] = lt ba 0:i64[]
            bf:i64[2] = add ba 2:i64[]
            bg:i64[2] = select_n be ba bf
            bh:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] bd
            bi:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] bg
            bj:i32[2,1] = broadcast_in_dim[
              broadcast_dimensions=(0,)
              shape=(2, 1)
              sharding=None
            ] bh
            bk:i32[2,1] = broadcast_in_dim[
              broadcast_dimensions=(0,)
              shape=(2, 1)
              sharding=None
            ] bi
            bl:i32[2,2] = concatenate[dimension=1] bj bk
            bm:f64[2] = gather[
              dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
              fill_value=None
              indices_are_sorted=False
              mode=GatherScatterMode.PROMISE_IN_BOUNDS
              slice_sizes=(1, 1)
              unique_indices=False
            ] y bl
          in (bm,) }
      )
      branches_platforms=(('mosaic',), None)
    ] r q l
  in (s,) } in
let _where = { lambda ; bn:bool[1000] bo:f64[1000] bp:f64[1000]. let
    bq:f64[1000] = select_n bn bp bo
  in (bq,) } in
let polyval = { lambda ; br:f64[4] bs:f64[1000]. let
    bt:f64[1000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1000,)
      sharding=None
    ] 0.0:f64[]
    bu:f64[1000] = scan[
      _split_transpose=False
      jaxpr={ lambda ; bv:f64[1000] bw:f64[1000] bx:f64[]. let
          by:f64[1000] = mul bw bv
          bz:f64[1000] = add by bx
        in (bz,) }
      length=4
      linear=(False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=16
    ] bs bt br
  in (bu,) } in
let polyval1 = { lambda ; ca:f64[5] cb:f64[1000]. let
    cc:f64[1000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1000,)
      sharding=None
    ] 0.0:f64[]
    cd:f64[1000] = scan[
      _split_transpose=False
      jaxpr={ lambda ; ce:f64[1000] cf:f64[1000] cg:f64[]. let
          ch:f64[1000] = mul cf ce
          ci:f64[1000] = add ch cg
        in (ci,) }
      length=5
      linear=(False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=16
    ] cb cc ca
  in (cd,) } in
let _where1 = { lambda ; cj:bool[1] ck:f64[1000] cl:f64[1000]. let
    cm:bool[1000] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(1000,)
      sharding=None
    ] cj
    cn:f64[1000] = select_n cm cl ck
  in (cn,) } in
let isinf = { lambda ; co:f64[1000]. let
    cp:f64[1000] = abs co
    cq:bool[1000] = eq cp inf:f64[]
  in (cq,) } in
let _where2 = { lambda ; cr:bool[1000] cs:f64[] ct:f64[1000]. let
    cu:f64[1000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1000,)
      sharding=None
    ] cs
    cv:f64[1000] = select_n cr ct cu
  in (cv,) } in
let jaxpr = { lambda ; t:f64[2,2] u:f64[2,2]. let
    v:f64[2,2] = mul t u
    w:f64[2] = reduce_sum[axes=(0,)] v
  in (w,) } in
let jaxpr1 = { lambda ; x:f64[2,2] y:f64[2,2]. let
    z:i64[2] = iota[dimension=0 dtype=int64 shape=(2,) sharding=None] 
    ba:i64[2] = iota[dimension=0 dtype=int64 shape=(2,) sharding=None] 
    bb:bool[2] = lt z 0:i64[]
    bc:i64[2] = add z 2:i64[]
    bd:i64[2] = select_n bb z bc
    be:bool[2] = lt ba 0:i64[]
    bf:i64[2] = add ba 2:i64[]
    bg:i64[2] = select_n be ba bf
    bh:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] bd
    bi:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] bg
    bj:i32[2,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1)
      sharding=None
    ] bh
    bk:i32[2,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1)
      sharding=None
    ] bi
    bl:i32[2,2] = concatenate[dimension=1] bj bk
    bm:f64[2] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1), operand_batching_dims=(), start_indices_batching_dims=())
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] y bl
  in (bm,) } in
let atleast_2d = { lambda ; cw:f64[2,2]. let  in (cw,) } in
let _where3 = { lambda ; cx:bool[1] cy:f64[1000] cz:f64[1]. let
    da:bool[1000] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(1000,)
      sharding=None
    ] cx
    db:f64[1000] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(1000,)
      sharding=None
    ] cz
    dc:f64[1000] = select_n da db cy
  in (dc,) } in
let jaxpr2 = { lambda ; bv:f64[1000] bw:f64[1000] bx:f64[]. let
    by:f64[1000] = mul bw bv
    bz:f64[1000] = add by bx
  in (bz,) } in
let jaxpr3 = { lambda ; ce:f64[1000] cf:f64[1000] cg:f64[]. let
    ch:f64[1000] = mul cf ce
    ci:f64[1000] = add ch cg
  in (ci,) } in
let _where4 = { lambda ; bn:bool[1000] bo:f64[1000] bp:f64[1000]. let
    bq:f64[1000] = select_n bn bp bo
  in (bq,) } in
{ lambda ; dd:u32[2] de:f64[4] df:f64[5] dg:f64[4] dh:f64[5] di:f64[2,2] dj:f64[2,2]
    dk:f64[1,2,1] dl:f64[1,2,2] dm:f64[1,1,1] dn:f64[2,1,2] do:f64[2,1] dp:f64[2,1]
    dq:i64[] dr:f64[2,2] ds:f64[2,2,2] dt:f64[1,2,2] du:f64[1,2,2] dv:f64[4000]. let
    dw:i64[] = add dq 1:i64[]
    dx:key<fry>[] = random_wrap[impl=fry] dd
    dy:key<fry>[2] = random_split[shape=(2,)] dx
    dz:u32[2,2] = random_unwrap dy
    ea:u32[1,2] = slice[
      limit_indices=(2, 2)
      start_indices=(1, 0)
      strides=(1, 1)
    ] dz
    eb:u32[2] = squeeze[dimensions=(0,)] ea
    ec:f64[2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2,)
      sharding=None
    ] 0.0:f64[]
    ed:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] 
    ee:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] 
    ef:i64[2,2] = add ed 0:i64[]
    eg:bool[2,2] = eq ef ee
    eh:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] eg
    ei:f64[2,2] = pjit[
      name=cholesky
      jaxpr={ lambda ; eh:f64[2,2]. let
          ej:f64[2,2] = transpose[permutation=(1, 0)] eh
          ek:f64[2,2] = add eh ej
          el:f64[2,2] = div ek 2.0:f64[]
          em:f64[2,2] = cholesky el
          en:i32[2,2] = iota[dimension=0 dtype=int32 shape=(2, 2) sharding=None] 
          eo:i32[2,2] = add en 0:i32[]
          ep:i32[2,2] = iota[dimension=1 dtype=int32 shape=(2, 2) sharding=None] 
          eq:bool[2,2] = ge eo ep
          er:f64[2,2] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(2, 2)
            sharding=None
          ] 0.0:f64[]
          ei:f64[2,2] = select_n eq er em
        in (ei,) }
    ] eh
    es:key<fry>[] = random_wrap[impl=fry] eb
    et:f64[4000] = pjit[
      name=_normal
      jaxpr={ lambda ; es:key<fry>[]. let
          et:f64[4000] = pjit[
            name=_normal_real
            jaxpr={ lambda ; es:key<fry>[]. let
                eu:f64[4000] = pjit[
                  name=_uniform
                  jaxpr={ lambda ; es:key<fry>[] ev:f64[] ew:f64[]. let
                      ex:f64[1] = broadcast_in_dim[
                        broadcast_dimensions=()
                        shape=(1,)
                        sharding=None
                      ] ev
                      ey:f64[1] = broadcast_in_dim[
                        broadcast_dimensions=()
                        shape=(1,)
                        sharding=None
                      ] ew
                      ez:u64[4000] = random_bits[bit_width=64 shape=(4000,)] es
                      fa:u64[4000] = shift_right_logical ez 12:u64[]
                      fb:u64[4000] = or fa 4607182418800017408:u64[]
                      fc:f64[4000] = bitcast_convert_type[new_dtype=float64] fb
                      fd:f64[4000] = sub fc 1.0:f64[]
                      fe:f64[1] = sub ey ex
                      ff:f64[4000] = mul fd fe
                      fg:f64[4000] = add ff ex
                      eu:f64[4000] = max ex fg
                    in (eu,) }
                ] es -0.9999999999999999:f64[] 1.0:f64[]
                fh:f64[4000] = erf_inv eu
                et:f64[4000] = mul 1.4142135623730951:f64[] fh
              in (et,) }
          ] es
        in (et,) }
    ] es
    fi:f64[4000] = mul et 1.0:f64[]
    fj:f64[4000] = add fi 0.0:f64[]
    fk:f64[4000] = mul fj 1.0:f64[]
    fl:f64[4000] = add fk 0.0:f64[]
    fm:f64[2000,2] = reshape[dimensions=None new_sizes=(2000, 2) sharding=None] fl
    fn:f64[1000,2,2] = reshape[
      dimensions=None
      new_sizes=(1000, 2, 2)
      sharding=None
    ] fm
    fo:f64[1000,2,2,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(1000, 2, 2, 1)
      sharding=None
    ] fn
    fp:f64[2,2] = pjit[
      name=tril
      jaxpr={ lambda ; ei:f64[2,2]. let
          fq:i32[2,2] = iota[dimension=0 dtype=int32 shape=(2, 2) sharding=None] 
          fr:i32[2,2] = add fq 0:i32[]
          fs:i32[2,2] = iota[dimension=1 dtype=int32 shape=(2, 2) sharding=None] 
          ft:bool[2,2] = ge fr fs
          fu:f64[2,2] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(2, 2)
            sharding=None
          ] 0.0:f64[]
          fp:f64[2,2] = select_n ft fu ei
        in (fp,) }
    ] ei
    fv:f64[2,1000,2,1] = dot_general[
      dimension_numbers=(([1], [2]), ([], []))
      preferred_element_type=float64
    ] fp fo
    fw:f64[1000,2,2,1] = transpose[permutation=(1, 2, 0, 3)] fv
    fx:f64[1000,2,2] = squeeze[dimensions=(3,)] fw
    fy:f64[1,1,2] = broadcast_in_dim[
      broadcast_dimensions=(2,)
      shape=(1, 1, 2)
      sharding=None
    ] ec
    fz:f64[1000,2,2] = add fx fy
    ga:f64[2,2,1000] = dot_general[
      dimension_numbers=(([2], [2]), ([0], [1]))
      preferred_element_type=float64
    ] ds fz
    gb:f64[1000,2,2] = transpose[permutation=(2, 0, 1)] ga
    gc:f64[2,1000,2] = transpose[permutation=(1, 0, 2)] gb
    gd:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] 
    ge:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] 
    gf:i64[2,2] = add gd 0:i64[]
    gg:bool[2,2] = eq gf ge
    gh:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] gg
    gi:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] gh
    gj:f64[2,2,2] = concatenate[dimension=0] gi dt
    gk:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] 
    gl:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] 
    gm:i64[2,2] = add gk 0:i64[]
    gn:bool[2,2] = eq gm gl
    go:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] gn
    gp:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] go
    gq:f64[2,2,2] = concatenate[dimension=0] gp du
    gr:f64[1000,2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1000, 2)
      sharding=None
    ] 0.0:f64[]
    _:f64[1000,2] gs:f64[2,1000,2] = scan[
      _split_transpose=False
      jaxpr={ lambda ; gt:f64[1000,2] gu:f64[1000,2] gv:f64[2,2] gw:f64[2,2]. let
          gx:f64[1000,2] = pjit[
            name=_solve_triangular
            jaxpr={ lambda ; gv:f64[2,2] gt:f64[1000,2]. let
                gy:f64[1000,2,1] = broadcast_in_dim[
                  broadcast_dimensions=(0, 1)
                  shape=(1000, 2, 1)
                  sharding=None
                ] gt
                gz:f64[2,1,1000] = transpose[permutation=(1, 2, 0)] gy
                ha:f64[2,1000] = reshape[
                  dimensions=None
                  new_sizes=(2, 1000)
                  sharding=None
                ] gz
                hb:f64[2,1000] = triangular_solve[
                  conjugate_a=False
                  left_side=True
                  lower=True
                  transpose_a=False
                  unit_diagonal=False
                ] gv ha
                hc:f64[2,1,1000] = reshape[
                  dimensions=None
                  new_sizes=(2, 1, 1000)
                  sharding=None
                ] hb
                hd:f64[2,1,1000] = slice[
                  limit_indices=(2, 1, 1000)
                  start_indices=(0, 0, 0)
                  strides=None
                ] hc
                he:f64[1000,2,1] = transpose[permutation=(2, 0, 1)] hd
                gx:f64[1000,2] = squeeze[dimensions=(2,)] he
              in (gx,) }
          ] gv gt
          hf:f64[2,1000] = dot_general[
            dimension_numbers=(([1], [1]), ([], []))
            preferred_element_type=float64
          ] gw gx
          hg:f64[1000,2] = transpose[permutation=(1, 0)] hf
          hh:f64[1000,2] = add hg gu
        in (hh, hh) }
      length=2
      linear=(False, False, False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] gr gc gj gq
    hi:f64[1000,2,2] = transpose[permutation=(1, 0, 2)] gs
    hj:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] dr
    hk:f64[1000,2,2] = add hi hj
    hl:f64[1000,4] = reshape[dimensions=None new_sizes=(1000, 4) sharding=None] fz
    hm:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] dr
    hn:f64[1,2,2] = mul 2.0:f64[] hm
    ho:f64[1000,2,2] = sub hn hk
    hp:f64[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 1.0:f64[]
    hq:f64[1] = mul 4.0:f64[] hp
    hr:f64[1000] = pjit[name=norm jaxpr=norm] hl
    hs:f64[1000] = integer_pow[y=2] hr
    ht:f64[1] = mul 0.5:f64[] hq
    hu:f64[1000] = mul 0.5:f64[] hs
    hv:f64[1000] = igamma ht hu
    hw:f64[1000] = sub 1.0:f64[] hv
    hx:f64[1] = mul 0.5:f64[] hq
    hy:f64[1000] = custom_jvp_call[
      name=_igammainv_custom_gradient
      call_jaxpr={ lambda ; hz:f64[4] ia:f64[5] ib:f64[1] ic:f64[1000]. let
          id:f64[1000] = sub 1.0:f64[] ic
          ie:f64[1] = lgamma ib
          if:f64[1000] = neg ic
          ig:f64[1000] = log1p if
          ih:f64[1000] = add ig ie
          ii:f64[1000] = neg ih
          ij:f64[1] = sub ib 1.0:f64[]
          ik:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; il:f64[1] im:f64[1000]. let
                in:bool[1] = ne il 0.0:f64[]
                io:f64[1000] = log im
                ip:f64[1000] = mul il io
                iq:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                ir:f64[1000] = pjit[name=_where jaxpr=_where3] in ip iq
              in (ir,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] ij ii
          is:f64[1000] = square ik
          it:f64[1000] = mul is ik
          iu:f64[1000] = square is
          iv:f64[1] = square ib
          iw:f64[1] = mul iv ib
          ix:f64[1] = sub ib 1.0:f64[]
          iy:f64[1000] = add 1.0:f64[] ik
          iz:f64[1000] = mul ix iy
          ja:f64[1] = sub ib 1.0:f64[]
          jb:f64[1] = mul 3.0:f64[] ib
          jc:f64[1] = sub jb 5.0:f64[]
          jd:f64[1] = div jc 2.0:f64[]
          je:f64[1] = sub ib 2.0:f64[]
          jf:f64[1000] = div ik 2.0:f64[]
          jg:f64[1000] = sub je jf
          jh:f64[1000] = mul ik jg
          ji:f64[1000] = add jd jh
          jj:f64[1000] = mul ja ji
          jk:f64[1] = sub ib 1.0:f64[]
          jl:f64[1000] = div it 3.0:f64[]
          jm:f64[1] = mul 3.0:f64[] ib
          jn:f64[1] = sub jm 5.0:f64[]
          jo:f64[1000] = mul jn is
          jp:f64[1000] = div jo 2.0:f64[]
          jq:f64[1000] = sub jl jp
          jr:f64[1] = mul 6.0:f64[] ib
          js:f64[1] = sub iv jr
          jt:f64[1] = add js 7.0:f64[]
          ju:f64[1000] = mul jt ik
          jv:f64[1000] = add jq ju
          jw:f64[1] = mul 11.0:f64[] iv
          jx:f64[1] = mul 46.0:f64[] ib
          jy:f64[1] = sub jw jx
          jz:f64[1] = add jy 47.0:f64[]
          ka:f64[1] = div jz 6.0:f64[]
          kb:f64[1000] = add jv ka
          kc:f64[1000] = mul jk kb
          kd:f64[1] = sub ib 1.0:f64[]
          ke:f64[1000] = neg iu
          kf:f64[1000] = div ke 4.0:f64[]
          kg:f64[1] = mul 11.0:f64[] ib
          kh:f64[1] = sub kg 17.0:f64[]
          ki:f64[1000] = mul kh it
          kj:f64[1000] = div ki 6.0:f64[]
          kk:f64[1000] = add kf kj
          kl:f64[1] = mul -3.0:f64[] iv
          km:f64[1] = mul 13.0:f64[] ib
          kn:f64[1] = add kl km
          ko:f64[1] = sub kn 13.0:f64[]
          kp:f64[1000] = mul ko is
          kq:f64[1000] = add kk kp
          kr:f64[1] = mul 2.0:f64[] iw
          ks:f64[1] = mul 25.0:f64[] iv
          kt:f64[1] = sub kr ks
          ku:f64[1] = mul 72.0:f64[] ib
          kv:f64[1] = add kt ku
          kw:f64[1] = sub kv 61.0:f64[]
          kx:f64[1000] = mul kw ik
          ky:f64[1000] = div kx 2.0:f64[]
          kz:f64[1000] = add kq ky
          la:f64[1] = mul 25.0:f64[] iw
          lb:f64[1] = mul 195.0:f64[] iv
          lc:f64[1] = sub la lb
          ld:f64[1] = mul 477.0:f64[] ib
          le:f64[1] = add lc ld
          lf:f64[1] = sub le 379.0:f64[]
          lg:f64[1] = div lf 12.0:f64[]
          lh:f64[1000] = add kz lg
          li:f64[1000] = mul kd lh
          lj:f64[1000] = add ii ik
          lk:f64[1000] = div li ii
          ll:f64[1000] = add lk kc
          lm:f64[1000] = div ll ii
          ln:f64[1000] = div jj ii
          lo:f64[1000] = add lm ln
          lp:f64[1000] = add lo iz
          lq:f64[1000] = div lp ii
          lr:f64[1000] = add lj lq
          ls:f64[1000] = neg ih
          lt:f64[1] = sub 1.0:f64[] ib
          lu:f64[1000] = neg ih
          lv:f64[1000] = log lu
          lw:f64[1000] = mul lt lv
          lx:f64[1000] = sub ls lw
          ly:f64[1000] = square lx
          lz:bool[1000] = gt ih -4.605170185988091:f64[]
          ma:f64[1000] = neg ih
          mb:f64[1] = sub 1.0:f64[] ib
          mc:f64[1000] = log lx
          md:f64[1000] = mul mb mc
          me:f64[1000] = sub ma md
          mf:f64[1] = sub 3.0:f64[] ib
          mg:f64[1] = mul 2.0:f64[] mf
          mh:f64[1000] = mul mg lx
          mi:f64[1000] = add ly mh
          mj:f64[1] = sub 2.0:f64[] ib
          mk:f64[1] = sub 3.0:f64[] ib
          ml:f64[1] = mul mj mk
          mm:f64[1000] = add mi ml
          mn:f64[1] = sub 5.0:f64[] ib
          mo:f64[1000] = mul mn lx
          mp:f64[1000] = add ly mo
          mq:f64[1000] = add mp 2.0:f64[]
          mr:f64[1000] = div mm mq
          ms:f64[1000] = log mr
          mt:f64[1000] = sub me ms
          mu:f64[1000] = pjit[name=_where jaxpr=_where] lz mt lr
          mv:bool[1000] = ge ih -1.8971199848858813:f64[]
          mw:f64[1000] = neg ih
          mx:f64[1] = sub ib 1.0:f64[]
          my:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; mz:f64[1] na:f64[1000]. let
                nb:bool[1] = ne mz 0.0:f64[]
                nc:f64[1000] = log na
                nd:f64[1000] = mul mz nc
                ne:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                nf:f64[1000] = pjit[name=_where jaxpr=_where3] nb nd ne
              in (nf,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] mx lx
          ng:f64[1000] = add mw my
          nh:f64[1] = sub 1.0:f64[] ib
          ni:f64[1000] = add 1.0:f64[] lx
          nj:f64[1000] = div nh ni
          nk:f64[1000] = log1p nj
          nl:f64[1000] = sub ng nk
          nm:f64[1000] = pjit[name=_where jaxpr=_where] mv nl mu
          nn:f64[1000] = exp ih
          no:f64[1000] = sub -0.5772156649015329:f64[] nn
          np:f64[1000] = exp no
          nq:f64[1000] = exp np
          nr:f64[1000] = mul np nq
          ns:bool[1] = lt ib 0.3:f64[]
          nt:bool[1000] = ge ih -1.0498221244986778:f64[]
          nu:bool[1000] = and ns nt
          nv:f64[1000] = exp nr
          nw:f64[1000] = mul np nv
          nx:f64[1000] = pjit[name=_where jaxpr=_where] nu nw nm
          ny:f64[1000] = exp ih
          nz:f64[1000] = mul ny id
          oa:bool[1000] = gt nz 1e-08:f64[]
          ob:bool[1000] = gt id 1e-05:f64[]
          oc:bool[1000] = and oa ob
          od:f64[1] = exp ie
          oe:f64[1000] = mul ic od
          of:f64[1000] = mul oe ib
          og:f64[1] = integer_pow[y=-1] ib
          oh:f64[1000] = pow of og
          oi:f64[1000] = neg id
          oj:f64[1000] = div oi ib
          ok:f64[1000] = sub oj 0.5772156649015329:f64[]
          ol:f64[1000] = exp ok
          om:f64[1000] = pjit[name=_where jaxpr=_where] oc oh ol
          on:bool[1000] = gt ih -0.5108256237659907:f64[]
          oo:bool[1000] = ge ih -0.7985076962177716:f64[]
          op:bool[1] = ge ib 0.3:f64[]
          oq:bool[1000] = and oo op
          or:bool[1000] = or on oq
          os:f64[1] = add ib 1.0:f64[]
          ot:f64[1000] = div om os
          ou:f64[1000] = sub 1.0:f64[] ot
          ov:f64[1000] = div om ou
          ow:f64[1000] = pjit[name=_where jaxpr=_where] or ov nx
          ox:f64[1] = sqrt ib
          oy:bool[1000] = lt ic 0.5:f64[]
          oz:f64[1000] = log ic
          pa:f64[1000] = mul -2.0:f64[] oz
          pb:f64[1000] = sqrt pa
          pc:f64[1000] = log id
          pd:f64[1000] = mul -2.0:f64[] pc
          pe:f64[1000] = sqrt pd
          pf:f64[1000] = pjit[name=_where jaxpr=_where] oy pb pe
          pg:f64[1000] = pjit[name=polyval jaxpr=polyval] hz pf
          ph:f64[1000] = pjit[name=polyval jaxpr=polyval1] ia pf
          pi:f64[1000] = div pg ph
          pj:f64[1000] = sub pf pi
          pk:bool[1000] = lt ic 0.5:f64[]
          pl:f64[1000] = neg pj
          pm:f64[1000] = pjit[name=_where jaxpr=_where] pk pl pj
          pn:f64[1000] = square pm
          po:f64[1000] = mul pn pm
          pp:f64[1000] = square pn
          pq:f64[1000] = mul pp pm
          pr:f64[1000] = mul pm ox
          ps:f64[1000] = add ib pr
          pt:f64[1000] = sub pn 1.0:f64[]
          pu:f64[1000] = div pt 3.0:f64[]
          pv:f64[1000] = add ps pu
          pw:f64[1000] = mul 7.0:f64[] pm
          px:f64[1000] = sub po pw
          py:f64[1] = mul 36.0:f64[] ox
          pz:f64[1000] = div px py
          qa:f64[1000] = add pv pz
          qb:f64[1000] = mul 3.0:f64[] pp
          qc:f64[1000] = mul 7.0:f64[] pn
          qd:f64[1000] = add qb qc
          qe:f64[1000] = sub qd 16.0:f64[]
          qf:f64[1] = mul 810.0:f64[] ib
          qg:f64[1000] = div qe qf
          qh:f64[1000] = sub qa qg
          qi:f64[1000] = mul 9.0:f64[] pq
          qj:f64[1000] = mul 256.0:f64[] po
          qk:f64[1000] = add qi qj
          ql:f64[1000] = mul 433.0:f64[] pm
          qm:f64[1000] = sub qk ql
          qn:f64[1] = mul 38880.0:f64[] ib
          qo:f64[1] = mul qn ox
          qp:f64[1000] = div qm qo
          qq:f64[1000] = add qh qp
          qr:f64[1] = sub ib 1.0:f64[]
          qs:f64[1] = mul ib qr
          qt:f64[1] = copy qs
          qu:f64[1] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1,)
            sharding=None
          ] 2.0:f64[]
          qv:f64[1] = custom_jvp_call[
            name=_maximum_
            call_jaxpr={ lambda ; qw:f64[1] qx:f64[1]. let
                qy:f64[1] = max qw qx
              in (qy,) }
            jvp=_maximum_jvp
            symbolic_zeros=False
          ] qu qt
          qz:f64[1] = neg qv
          ra:f64[1] = mul qz 2.302585092994046:f64[]
          rb:bool[1000] = le ih ra
          rc:f64[1000] = neg ih
          rd:f64[1] = sub ib 1.0:f64[]
          re:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; rf:f64[1] rg:f64[1000]. let
                rh:bool[1] = ne rf 0.0:f64[]
                ri:f64[1000] = log rg
                rj:f64[1000] = mul rf ri
                rk:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                rl:f64[1000] = pjit[name=_where jaxpr=_where3] rh rj rk
              in (rl,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] rd rc
          rm:f64[1000] = square re
          rn:f64[1000] = mul rm re
          ro:f64[1000] = square rm
          rp:f64[1] = square ib
          rq:f64[1] = mul rp ib
          rr:f64[1] = sub ib 1.0:f64[]
          rs:f64[1000] = add 1.0:f64[] re
          rt:f64[1000] = mul rr rs
          ru:f64[1] = sub ib 1.0:f64[]
          rv:f64[1] = mul 3.0:f64[] ib
          rw:f64[1] = sub rv 5.0:f64[]
          rx:f64[1] = div rw 2.0:f64[]
          ry:f64[1] = sub ib 2.0:f64[]
          rz:f64[1000] = div re 2.0:f64[]
          sa:f64[1000] = sub ry rz
          sb:f64[1000] = mul re sa
          sc:f64[1000] = add rx sb
          sd:f64[1000] = mul ru sc
          se:f64[1] = sub ib 1.0:f64[]
          sf:f64[1000] = div rn 3.0:f64[]
          sg:f64[1] = mul 3.0:f64[] ib
          sh:f64[1] = sub sg 5.0:f64[]
          si:f64[1000] = mul sh rm
          sj:f64[1000] = div si 2.0:f64[]
          sk:f64[1000] = sub sf sj
          sl:f64[1] = mul 6.0:f64[] ib
          sm:f64[1] = sub rp sl
          sn:f64[1] = add sm 7.0:f64[]
          so:f64[1000] = mul sn re
          sp:f64[1000] = add sk so
          sq:f64[1] = mul 11.0:f64[] rp
          sr:f64[1] = mul 46.0:f64[] ib
          ss:f64[1] = sub sq sr
          st:f64[1] = add ss 47.0:f64[]
          su:f64[1] = div st 6.0:f64[]
          sv:f64[1000] = add sp su
          sw:f64[1000] = mul se sv
          sx:f64[1] = sub ib 1.0:f64[]
          sy:f64[1000] = neg ro
          sz:f64[1000] = div sy 4.0:f64[]
          ta:f64[1] = mul 11.0:f64[] ib
          tb:f64[1] = sub ta 17.0:f64[]
          tc:f64[1000] = mul tb rn
          td:f64[1000] = div tc 6.0:f64[]
          te:f64[1000] = add sz td
          tf:f64[1] = mul -3.0:f64[] rp
          tg:f64[1] = mul 13.0:f64[] ib
          th:f64[1] = add tf tg
          ti:f64[1] = sub th 13.0:f64[]
          tj:f64[1000] = mul ti rm
          tk:f64[1000] = add te tj
          tl:f64[1] = mul 2.0:f64[] rq
          tm:f64[1] = mul 25.0:f64[] rp
          tn:f64[1] = sub tl tm
          to:f64[1] = mul 72.0:f64[] ib
          tp:f64[1] = add tn to
          tq:f64[1] = sub tp 61.0:f64[]
          tr:f64[1000] = mul tq re
          ts:f64[1000] = div tr 2.0:f64[]
          tt:f64[1000] = add tk ts
          tu:f64[1] = mul 25.0:f64[] rq
          tv:f64[1] = mul 195.0:f64[] rp
          tw:f64[1] = sub tu tv
          tx:f64[1] = mul 477.0:f64[] ib
          ty:f64[1] = add tw tx
          tz:f64[1] = sub ty 379.0:f64[]
          ua:f64[1] = div tz 12.0:f64[]
          ub:f64[1000] = add tt ua
          uc:f64[1000] = mul sx ub
          ud:f64[1000] = add rc re
          ue:f64[1000] = div uc rc
          uf:f64[1000] = add ue sw
          ug:f64[1000] = div uf rc
          uh:f64[1000] = div sd rc
          ui:f64[1000] = add ug uh
          uj:f64[1000] = add ui rt
          uk:f64[1000] = div uj rc
          ul:f64[1000] = add ud uk
          um:f64[1000] = neg ih
          un:f64[1] = sub ib 1.0:f64[]
          uo:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; up:f64[1] uq:f64[1000]. let
                ur:bool[1] = ne up 0.0:f64[]
                us:f64[1000] = log uq
                ut:f64[1000] = mul up us
                uu:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                uv:f64[1000] = pjit[name=_where jaxpr=_where3] ur ut uu
              in (uv,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] un qq
          uw:f64[1000] = add um uo
          ux:f64[1] = sub 1.0:f64[] ib
          uy:f64[1000] = add 1.0:f64[] qq
          uz:f64[1000] = div ux uy
          va:f64[1000] = log1p uz
          vb:f64[1000] = sub uw va
          vc:f64[1000] = neg ih
          vd:f64[1] = sub ib 1.0:f64[]
          ve:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; vf:f64[1] vg:f64[1000]. let
                vh:bool[1] = ne vf 0.0:f64[]
                vi:f64[1000] = log vg
                vj:f64[1000] = mul vf vi
                vk:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                vl:f64[1000] = pjit[name=_where jaxpr=_where3] vh vj vk
              in (vl,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] vd vb
          vm:f64[1000] = add vc ve
          vn:f64[1] = sub 1.0:f64[] ib
          vo:f64[1000] = add 1.0:f64[] vb
          vp:f64[1000] = div vn vo
          vq:f64[1000] = log1p vp
          vr:f64[1000] = sub vm vq
          vs:f64[1000] = pjit[name=_where jaxpr=_where] rb ul vr
          vt:f64[1] = mul 3.0:f64[] ib
          vu:bool[1000] = lt qq vt
          vv:f64[1000] = pjit[name=_where jaxpr=_where] vu qq vs
          vw:bool[1] = ge ib 500.0:f64[]
          vx:f64[1000] = div qq ib
          vy:f64[1000] = sub 1.0:f64[] vx
          vz:f64[1000] = abs vy
          wa:bool[1000] = lt vz 1e-06:f64[]
          wb:bool[1000] = and vw wa
          wc:f64[1000] = pjit[name=_where jaxpr=_where] wb qq vv
          wd:f64[1000] = log ic
          we:f64[1] = add ib 1.0:f64[]
          wf:f64[1] = lgamma we
          wg:f64[1000] = add wd wf
          wh:f64[1000] = add wg qq
          wi:f64[1000] = div wh ib
          wj:f64[1000] = exp wi
          wk:f64[1] = add ib 1.0:f64[]
          wl:f64[1000] = div wj wk
          wm:f64[1] = add ib 2.0:f64[]
          wn:f64[1000] = div wj wm
          wo:f64[1000] = add 1.0:f64[] wn
          wp:f64[1000] = mul wl wo
          wq:f64[1000] = log1p wp
          wr:f64[1000] = add wg wj
          ws:f64[1000] = sub wr wq
          wt:f64[1000] = div ws ib
          wu:f64[1000] = exp wt
          wv:f64[1] = add ib 1.0:f64[]
          ww:f64[1000] = div wu wv
          wx:f64[1] = add ib 2.0:f64[]
          wy:f64[1000] = div wu wx
          wz:f64[1000] = add 1.0:f64[] wy
          xa:f64[1000] = mul ww wz
          xb:f64[1000] = log1p xa
          xc:f64[1000] = add wg wu
          xd:f64[1000] = sub xc xb
          xe:f64[1000] = div xd ib
          xf:f64[1000] = exp xe
          xg:f64[1] = add ib 1.0:f64[]
          xh:f64[1000] = div xf xg
          xi:f64[1] = add ib 2.0:f64[]
          xj:f64[1000] = div xf xi
          xk:f64[1] = add ib 3.0:f64[]
          xl:f64[1000] = div xf xk
          xm:f64[1000] = add 1.0:f64[] xl
          xn:f64[1000] = mul xj xm
          xo:f64[1000] = add 1.0:f64[] xn
          xp:f64[1000] = mul xh xo
          xq:f64[1000] = log1p xp
          xr:f64[1000] = add wg xf
          xs:f64[1000] = sub xr xq
          xt:f64[1000] = div xs ib
          xu:f64[1000] = exp xt
          xv:f64[1] = add ib 1.0:f64[]
          xw:f64[1] = mul 0.15:f64[] xv
          xx:bool[1000] = le qq xw
          xy:f64[1000] = pjit[name=_where jaxpr=_where] xx xu qq
          xz:bool[1000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1000,)
            sharding=None
          ] False:bool[]
          ya:f64[1000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1000,)
            sharding=None
          ] 1.0:f64[]
          yb:f64[1000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1000,)
            sharding=None
          ] 1.0:f64[]
          _:bool[1000] _:f64[] _:f64[1000] yc:f64[1000] = while[
            body_jaxpr={ lambda ; yd:f64[1000] ye:f64[1] yf:bool[1000] yg:f64[] yh:f64[1000]
                yi:f64[1000]. let
                yj:f64[1000] = mul yh yd
                yk:f64[1] = add ye yg
                yl:f64[1000] = div yj yk
                ym:f64[1000] = add yi yl
                yn:f64[1000] = pjit[name=_where jaxpr=_where4] yf yi ym
                yo:bool[1000] = lt yl 0.0001:f64[]
                yp:bool[] = gt yg 100.0:f64[]
                yq:bool[1000] = or yo yp
                yr:f64[] = add yg 1.0:f64[]
              in (yq, yr, yl, yn) }
            body_nconsts=2
            cond_jaxpr={ lambda ; ys:bool[1000] yt:f64[] yu:f64[1000] yv:f64[1000]. let
                yw:bool[1000] = not ys
                yx:bool[] = reduce_or[axes=(0,)] yw
              in (yx,) }
            cond_nconsts=0
          ] xy ib xz 1.0:f64[] ya yb
          yy:f64[1000] = log yc
          yz:f64[1000] = add wg xy
          za:f64[1000] = sub yz yy
          zb:f64[1000] = div za ib
          zc:f64[1000] = exp zb
          zd:f64[1] = add ib 1.0:f64[]
          ze:f64[1] = mul 0.01:f64[] zd
          zf:bool[1000] = le xy ze
          zg:f64[1] = add ib 1.0:f64[]
          zh:f64[1] = mul 0.7:f64[] zg
          zi:bool[1000] = gt xy zh
          zj:bool[1000] = or zf zi
          zk:f64[1000] = log zc
          zl:f64[1000] = mul ib zk
          zm:f64[1000] = sub zl zc
          zn:f64[1000] = sub zm wg
          zo:f64[1000] = add zn yy
          zp:f64[1000] = sub ib zc
          zq:f64[1000] = div zo zp
          zr:f64[1000] = sub 1.0:f64[] zq
          zs:f64[1000] = mul zc zr
          zt:f64[1000] = pjit[name=_where jaxpr=_where] zj xy zs
          zu:bool[1000] = le ic 0.5:f64[]
          zv:f64[1000] = pjit[name=_where jaxpr=_where] zu zt wc
          zw:bool[1] = lt ib 1.0:f64[]
          zx:f64[1000] = pjit[name=_where jaxpr=_where1] zw ow zv
          zy:bool[1] = eq ib 1.0:f64[]
          zz:f64[1000] = neg ig
          baa:f64[1000] = pjit[name=_where jaxpr=_where1] zy zz zx
          bab:f64[1000] = log baa
          bac:f64[1000] = mul ib bab
          bad:f64[1000] = sub bac baa
          bae:f64[1] = lgamma ib
          baf:f64[1000] = sub bad bae
          bag:f64[1000] = exp baf
          bah:bool[1000] = le ic 0.9:f64[]
          bai:bool[1000] = and bah True:bool[]
          baj:bool[1000] = gt id 0.9:f64[]
          bak:bool[1000] = and baj False:bool[]
          bal:bool[1000] = or bai bak
          bam:f64[1000] = igamma ib baa
          ban:f64[1000] = sub bam ic
          bao:f64[1000] = mul ban baa
          bap:f64[1000] = div bao bag
          baq:f64[1000] = igammac ib baa
          bar:f64[1000] = sub baq id
          bas:f64[1000] = neg bar
          bat:f64[1000] = mul bas baa
          bau:f64[1000] = div bat bag
          bav:f64[1000] = pjit[name=_where jaxpr=_where] bal bap bau
          baw:f64[1] = sub ib 1.0:f64[]
          bax:f64[1000] = div baw baa
          bay:f64[1000] = add -1.0:f64[] bax
          baz:bool[1000] = pjit[name=isinf jaxpr=isinf] bay
          bba:f64[1000] = sub baa bav
          bbb:f64[1000] = mul 0.5:f64[] bav
          bbc:f64[1000] = mul bbb bay
          bbd:f64[1000] = sub 1.0:f64[] bbc
          bbe:f64[1000] = div bav bbd
          bbf:f64[1000] = sub baa bbe
          bbg:f64[1000] = pjit[name=_where jaxpr=_where] baz bba bbf
          bbh:bool[1000] = eq bag 0.0:f64[]
          bbi:f64[1000] = pjit[name=_where jaxpr=_where] bbh baa bbg
          bbj:f64[1000] = log bbi
          bbk:f64[1000] = mul ib bbj
          bbl:f64[1000] = sub bbk bbi
          bbm:f64[1] = lgamma ib
          bbn:f64[1000] = sub bbl bbm
          bbo:f64[1000] = exp bbn
          bbp:bool[1000] = le ic 0.9:f64[]
          bbq:bool[1000] = and bbp True:bool[]
          bbr:bool[1000] = gt id 0.9:f64[]
          bbs:bool[1000] = and bbr False:bool[]
          bbt:bool[1000] = or bbq bbs
          bbu:f64[1000] = igamma ib bbi
          bbv:f64[1000] = sub bbu ic
          bbw:f64[1000] = mul bbv bbi
          bbx:f64[1000] = div bbw bbo
          bby:f64[1000] = igammac ib bbi
          bbz:f64[1000] = sub bby id
          bca:f64[1000] = neg bbz
          bcb:f64[1000] = mul bca bbi
          bcc:f64[1000] = div bcb bbo
          bcd:f64[1000] = pjit[name=_where jaxpr=_where] bbt bbx bcc
          bce:f64[1] = sub ib 1.0:f64[]
          bcf:f64[1000] = div bce bbi
          bcg:f64[1000] = add -1.0:f64[] bcf
          bch:bool[1000] = pjit[name=isinf jaxpr=isinf] bcg
          bci:f64[1000] = sub bbi bcd
          bcj:f64[1000] = mul 0.5:f64[] bcd
          bck:f64[1000] = mul bcj bcg
          bcl:f64[1000] = sub 1.0:f64[] bck
          bcm:f64[1000] = div bcd bcl
          bcn:f64[1000] = sub bbi bcm
          bco:f64[1000] = pjit[name=_where jaxpr=_where] bch bci bcn
          bcp:bool[1000] = eq bbo 0.0:f64[]
          bcq:f64[1000] = pjit[name=_where jaxpr=_where] bcp bbi bco
          bcr:f64[1000] = log bcq
          bcs:f64[1000] = mul ib bcr
          bct:f64[1000] = sub bcs bcq
          bcu:f64[1] = lgamma ib
          bcv:f64[1000] = sub bct bcu
          bcw:f64[1000] = exp bcv
          bcx:bool[1000] = le ic 0.9:f64[]
          bcy:bool[1000] = and bcx True:bool[]
          bcz:bool[1000] = gt id 0.9:f64[]
          bda:bool[1000] = and bcz False:bool[]
          bdb:bool[1000] = or bcy bda
          bdc:f64[1000] = igamma ib bcq
          bdd:f64[1000] = sub bdc ic
          bde:f64[1000] = mul bdd bcq
          bdf:f64[1000] = div bde bcw
          bdg:f64[1000] = igammac ib bcq
          bdh:f64[1000] = sub bdg id
          bdi:f64[1000] = neg bdh
          bdj:f64[1000] = mul bdi bcq
          bdk:f64[1000] = div bdj bcw
          bdl:f64[1000] = pjit[name=_where jaxpr=_where] bdb bdf bdk
          bdm:f64[1] = sub ib 1.0:f64[]
          bdn:f64[1000] = div bdm bcq
          bdo:f64[1000] = add -1.0:f64[] bdn
          bdp:bool[1000] = pjit[name=isinf jaxpr=isinf] bdo
          bdq:f64[1000] = sub bcq bdl
          bdr:f64[1000] = mul 0.5:f64[] bdl
          bds:f64[1000] = mul bdr bdo
          bdt:f64[1000] = sub 1.0:f64[] bds
          bdu:f64[1000] = div bdl bdt
          bdv:f64[1000] = sub bcq bdu
          bdw:f64[1000] = pjit[name=_where jaxpr=_where] bdp bdq bdv
          bdx:bool[1000] = eq bcw 0.0:f64[]
          bdy:f64[1000] = pjit[name=_where jaxpr=_where] bdx bcq bdw
          bdz:bool[1] = lt ib 0.0:f64[]
          bea:bool[1000] = lt ic 0.0:f64[]
          beb:bool[1000] = or bdz bea
          bec:bool[1000] = gt ic 1.0:f64[]
          bed:bool[1000] = or beb bec
          bee:f64[1000] = pjit[name=_where jaxpr=_where2] bed nan:f64[] bdy
          bef:bool[1000] = eq ic 0.0:f64[]
          beg:f64[1000] = pjit[name=_where jaxpr=_where2] bef 0.0:f64[] bee
          beh:bool[1000] = eq ic 1.0:f64[]
          bei:f64[1000] = pjit[name=_where jaxpr=_where2] beh inf:f64[] beg
        in (bei,) }
      jvp=_igammainv_jvp
      num_consts=2
      symbolic_zeros=False
    ] de df hx hw
    bej:f64[1000] = mul 2.0:f64[] hy
    bek:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] dr
    bel:f64[1000] = div bej hs
    bem:f64[1000] = sqrt bel
    ben:f64[1000,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(1000, 1, 1)
      sharding=None
    ] bem
    beo:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] dr
    bep:f64[1000,2,2] = sub hk beo
    beq:f64[1000,2,2] = mul ben bep
    ber:f64[1000,2,2] = add bek beq
    bes:f64[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 1.0:f64[]
    bet:f64[1] = mul 4.0:f64[] bes
    beu:f64[1000] = pjit[name=norm jaxpr=norm] hl
    bev:f64[1000] = integer_pow[y=2] beu
    bew:f64[1] = mul 0.5:f64[] bet
    bex:f64[1000] = mul 0.5:f64[] bev
    bey:f64[1000] = igamma bew bex
    bez:f64[1000] = sub 1.0:f64[] bey
    bfa:f64[1] = mul 0.5:f64[] bet
    bfb:f64[1000] = custom_jvp_call[
      name=_igammainv_custom_gradient
      call_jaxpr={ lambda ; bfc:f64[4] bfd:f64[5] bfe:f64[1] bff:f64[1000]. let
          bfg:f64[1000] = sub 1.0:f64[] bff
          bfh:f64[1] = lgamma bfe
          bfi:f64[1000] = neg bff
          bfj:f64[1000] = log1p bfi
          bfk:f64[1000] = add bfj bfh
          bfl:f64[1000] = neg bfk
          bfm:f64[1] = sub bfe 1.0:f64[]
          bfn:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; bfo:f64[1] bfp:f64[1000]. let
                bfq:bool[1] = ne bfo 0.0:f64[]
                bfr:f64[1000] = log bfp
                bfs:f64[1000] = mul bfo bfr
                bft:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                bfu:f64[1000] = pjit[name=_where jaxpr=_where3] bfq bfs bft
              in (bfu,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] bfm bfl
          bfv:f64[1000] = square bfn
          bfw:f64[1000] = mul bfv bfn
          bfx:f64[1000] = square bfv
          bfy:f64[1] = square bfe
          bfz:f64[1] = mul bfy bfe
          bga:f64[1] = sub bfe 1.0:f64[]
          bgb:f64[1000] = add 1.0:f64[] bfn
          bgc:f64[1000] = mul bga bgb
          bgd:f64[1] = sub bfe 1.0:f64[]
          bge:f64[1] = mul 3.0:f64[] bfe
          bgf:f64[1] = sub bge 5.0:f64[]
          bgg:f64[1] = div bgf 2.0:f64[]
          bgh:f64[1] = sub bfe 2.0:f64[]
          bgi:f64[1000] = div bfn 2.0:f64[]
          bgj:f64[1000] = sub bgh bgi
          bgk:f64[1000] = mul bfn bgj
          bgl:f64[1000] = add bgg bgk
          bgm:f64[1000] = mul bgd bgl
          bgn:f64[1] = sub bfe 1.0:f64[]
          bgo:f64[1000] = div bfw 3.0:f64[]
          bgp:f64[1] = mul 3.0:f64[] bfe
          bgq:f64[1] = sub bgp 5.0:f64[]
          bgr:f64[1000] = mul bgq bfv
          bgs:f64[1000] = div bgr 2.0:f64[]
          bgt:f64[1000] = sub bgo bgs
          bgu:f64[1] = mul 6.0:f64[] bfe
          bgv:f64[1] = sub bfy bgu
          bgw:f64[1] = add bgv 7.0:f64[]
          bgx:f64[1000] = mul bgw bfn
          bgy:f64[1000] = add bgt bgx
          bgz:f64[1] = mul 11.0:f64[] bfy
          bha:f64[1] = mul 46.0:f64[] bfe
          bhb:f64[1] = sub bgz bha
          bhc:f64[1] = add bhb 47.0:f64[]
          bhd:f64[1] = div bhc 6.0:f64[]
          bhe:f64[1000] = add bgy bhd
          bhf:f64[1000] = mul bgn bhe
          bhg:f64[1] = sub bfe 1.0:f64[]
          bhh:f64[1000] = neg bfx
          bhi:f64[1000] = div bhh 4.0:f64[]
          bhj:f64[1] = mul 11.0:f64[] bfe
          bhk:f64[1] = sub bhj 17.0:f64[]
          bhl:f64[1000] = mul bhk bfw
          bhm:f64[1000] = div bhl 6.0:f64[]
          bhn:f64[1000] = add bhi bhm
          bho:f64[1] = mul -3.0:f64[] bfy
          bhp:f64[1] = mul 13.0:f64[] bfe
          bhq:f64[1] = add bho bhp
          bhr:f64[1] = sub bhq 13.0:f64[]
          bhs:f64[1000] = mul bhr bfv
          bht:f64[1000] = add bhn bhs
          bhu:f64[1] = mul 2.0:f64[] bfz
          bhv:f64[1] = mul 25.0:f64[] bfy
          bhw:f64[1] = sub bhu bhv
          bhx:f64[1] = mul 72.0:f64[] bfe
          bhy:f64[1] = add bhw bhx
          bhz:f64[1] = sub bhy 61.0:f64[]
          bia:f64[1000] = mul bhz bfn
          bib:f64[1000] = div bia 2.0:f64[]
          bic:f64[1000] = add bht bib
          bid:f64[1] = mul 25.0:f64[] bfz
          bie:f64[1] = mul 195.0:f64[] bfy
          bif:f64[1] = sub bid bie
          big:f64[1] = mul 477.0:f64[] bfe
          bih:f64[1] = add bif big
          bii:f64[1] = sub bih 379.0:f64[]
          bij:f64[1] = div bii 12.0:f64[]
          bik:f64[1000] = add bic bij
          bil:f64[1000] = mul bhg bik
          bim:f64[1000] = add bfl bfn
          bin:f64[1000] = div bil bfl
          bio:f64[1000] = add bin bhf
          bip:f64[1000] = div bio bfl
          biq:f64[1000] = div bgm bfl
          bir:f64[1000] = add bip biq
          bis:f64[1000] = add bir bgc
          bit:f64[1000] = div bis bfl
          biu:f64[1000] = add bim bit
          biv:f64[1000] = neg bfk
          biw:f64[1] = sub 1.0:f64[] bfe
          bix:f64[1000] = neg bfk
          biy:f64[1000] = log bix
          biz:f64[1000] = mul biw biy
          bja:f64[1000] = sub biv biz
          bjb:f64[1000] = square bja
          bjc:bool[1000] = gt bfk -4.605170185988091:f64[]
          bjd:f64[1000] = neg bfk
          bje:f64[1] = sub 1.0:f64[] bfe
          bjf:f64[1000] = log bja
          bjg:f64[1000] = mul bje bjf
          bjh:f64[1000] = sub bjd bjg
          bji:f64[1] = sub 3.0:f64[] bfe
          bjj:f64[1] = mul 2.0:f64[] bji
          bjk:f64[1000] = mul bjj bja
          bjl:f64[1000] = add bjb bjk
          bjm:f64[1] = sub 2.0:f64[] bfe
          bjn:f64[1] = sub 3.0:f64[] bfe
          bjo:f64[1] = mul bjm bjn
          bjp:f64[1000] = add bjl bjo
          bjq:f64[1] = sub 5.0:f64[] bfe
          bjr:f64[1000] = mul bjq bja
          bjs:f64[1000] = add bjb bjr
          bjt:f64[1000] = add bjs 2.0:f64[]
          bju:f64[1000] = div bjp bjt
          bjv:f64[1000] = log bju
          bjw:f64[1000] = sub bjh bjv
          bjx:f64[1000] = pjit[name=_where jaxpr=_where] bjc bjw biu
          bjy:bool[1000] = ge bfk -1.8971199848858813:f64[]
          bjz:f64[1000] = neg bfk
          bka:f64[1] = sub bfe 1.0:f64[]
          bkb:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; bkc:f64[1] bkd:f64[1000]. let
                bke:bool[1] = ne bkc 0.0:f64[]
                bkf:f64[1000] = log bkd
                bkg:f64[1000] = mul bkc bkf
                bkh:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                bki:f64[1000] = pjit[name=_where jaxpr=_where3] bke bkg bkh
              in (bki,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] bka bja
          bkj:f64[1000] = add bjz bkb
          bkk:f64[1] = sub 1.0:f64[] bfe
          bkl:f64[1000] = add 1.0:f64[] bja
          bkm:f64[1000] = div bkk bkl
          bkn:f64[1000] = log1p bkm
          bko:f64[1000] = sub bkj bkn
          bkp:f64[1000] = pjit[name=_where jaxpr=_where] bjy bko bjx
          bkq:f64[1000] = exp bfk
          bkr:f64[1000] = sub -0.5772156649015329:f64[] bkq
          bks:f64[1000] = exp bkr
          bkt:f64[1000] = exp bks
          bku:f64[1000] = mul bks bkt
          bkv:bool[1] = lt bfe 0.3:f64[]
          bkw:bool[1000] = ge bfk -1.0498221244986778:f64[]
          bkx:bool[1000] = and bkv bkw
          bky:f64[1000] = exp bku
          bkz:f64[1000] = mul bks bky
          bla:f64[1000] = pjit[name=_where jaxpr=_where] bkx bkz bkp
          blb:f64[1000] = exp bfk
          blc:f64[1000] = mul blb bfg
          bld:bool[1000] = gt blc 1e-08:f64[]
          ble:bool[1000] = gt bfg 1e-05:f64[]
          blf:bool[1000] = and bld ble
          blg:f64[1] = exp bfh
          blh:f64[1000] = mul bff blg
          bli:f64[1000] = mul blh bfe
          blj:f64[1] = integer_pow[y=-1] bfe
          blk:f64[1000] = pow bli blj
          bll:f64[1000] = neg bfg
          blm:f64[1000] = div bll bfe
          bln:f64[1000] = sub blm 0.5772156649015329:f64[]
          blo:f64[1000] = exp bln
          blp:f64[1000] = pjit[name=_where jaxpr=_where] blf blk blo
          blq:bool[1000] = gt bfk -0.5108256237659907:f64[]
          blr:bool[1000] = ge bfk -0.7985076962177716:f64[]
          bls:bool[1] = ge bfe 0.3:f64[]
          blt:bool[1000] = and blr bls
          blu:bool[1000] = or blq blt
          blv:f64[1] = add bfe 1.0:f64[]
          blw:f64[1000] = div blp blv
          blx:f64[1000] = sub 1.0:f64[] blw
          bly:f64[1000] = div blp blx
          blz:f64[1000] = pjit[name=_where jaxpr=_where] blu bly bla
          bma:f64[1] = sqrt bfe
          bmb:bool[1000] = lt bff 0.5:f64[]
          bmc:f64[1000] = log bff
          bmd:f64[1000] = mul -2.0:f64[] bmc
          bme:f64[1000] = sqrt bmd
          bmf:f64[1000] = log bfg
          bmg:f64[1000] = mul -2.0:f64[] bmf
          bmh:f64[1000] = sqrt bmg
          bmi:f64[1000] = pjit[name=_where jaxpr=_where] bmb bme bmh
          bmj:f64[1000] = pjit[name=polyval jaxpr=polyval] bfc bmi
          bmk:f64[1000] = pjit[name=polyval jaxpr=polyval1] bfd bmi
          bml:f64[1000] = div bmj bmk
          bmm:f64[1000] = sub bmi bml
          bmn:bool[1000] = lt bff 0.5:f64[]
          bmo:f64[1000] = neg bmm
          bmp:f64[1000] = pjit[name=_where jaxpr=_where] bmn bmo bmm
          bmq:f64[1000] = square bmp
          bmr:f64[1000] = mul bmq bmp
          bms:f64[1000] = square bmq
          bmt:f64[1000] = mul bms bmp
          bmu:f64[1000] = mul bmp bma
          bmv:f64[1000] = add bfe bmu
          bmw:f64[1000] = sub bmq 1.0:f64[]
          bmx:f64[1000] = div bmw 3.0:f64[]
          bmy:f64[1000] = add bmv bmx
          bmz:f64[1000] = mul 7.0:f64[] bmp
          bna:f64[1000] = sub bmr bmz
          bnb:f64[1] = mul 36.0:f64[] bma
          bnc:f64[1000] = div bna bnb
          bnd:f64[1000] = add bmy bnc
          bne:f64[1000] = mul 3.0:f64[] bms
          bnf:f64[1000] = mul 7.0:f64[] bmq
          bng:f64[1000] = add bne bnf
          bnh:f64[1000] = sub bng 16.0:f64[]
          bni:f64[1] = mul 810.0:f64[] bfe
          bnj:f64[1000] = div bnh bni
          bnk:f64[1000] = sub bnd bnj
          bnl:f64[1000] = mul 9.0:f64[] bmt
          bnm:f64[1000] = mul 256.0:f64[] bmr
          bnn:f64[1000] = add bnl bnm
          bno:f64[1000] = mul 433.0:f64[] bmp
          bnp:f64[1000] = sub bnn bno
          bnq:f64[1] = mul 38880.0:f64[] bfe
          bnr:f64[1] = mul bnq bma
          bns:f64[1000] = div bnp bnr
          bnt:f64[1000] = add bnk bns
          bnu:f64[1] = sub bfe 1.0:f64[]
          bnv:f64[1] = mul bfe bnu
          bnw:f64[1] = copy bnv
          bnx:f64[1] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1,)
            sharding=None
          ] 2.0:f64[]
          bny:f64[1] = custom_jvp_call[
            name=_maximum_
            call_jaxpr={ lambda ; bnz:f64[1] boa:f64[1]. let
                bob:f64[1] = max bnz boa
              in (bob,) }
            jvp=_maximum_jvp
            symbolic_zeros=False
          ] bnx bnw
          boc:f64[1] = neg bny
          bod:f64[1] = mul boc 2.302585092994046:f64[]
          boe:bool[1000] = le bfk bod
          bof:f64[1000] = neg bfk
          bog:f64[1] = sub bfe 1.0:f64[]
          boh:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; boi:f64[1] boj:f64[1000]. let
                bok:bool[1] = ne boi 0.0:f64[]
                bol:f64[1000] = log boj
                bom:f64[1000] = mul boi bol
                bon:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                boo:f64[1000] = pjit[name=_where jaxpr=_where3] bok bom bon
              in (boo,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] bog bof
          bop:f64[1000] = square boh
          boq:f64[1000] = mul bop boh
          bor:f64[1000] = square bop
          bos:f64[1] = square bfe
          bot:f64[1] = mul bos bfe
          bou:f64[1] = sub bfe 1.0:f64[]
          bov:f64[1000] = add 1.0:f64[] boh
          bow:f64[1000] = mul bou bov
          box:f64[1] = sub bfe 1.0:f64[]
          boy:f64[1] = mul 3.0:f64[] bfe
          boz:f64[1] = sub boy 5.0:f64[]
          bpa:f64[1] = div boz 2.0:f64[]
          bpb:f64[1] = sub bfe 2.0:f64[]
          bpc:f64[1000] = div boh 2.0:f64[]
          bpd:f64[1000] = sub bpb bpc
          bpe:f64[1000] = mul boh bpd
          bpf:f64[1000] = add bpa bpe
          bpg:f64[1000] = mul box bpf
          bph:f64[1] = sub bfe 1.0:f64[]
          bpi:f64[1000] = div boq 3.0:f64[]
          bpj:f64[1] = mul 3.0:f64[] bfe
          bpk:f64[1] = sub bpj 5.0:f64[]
          bpl:f64[1000] = mul bpk bop
          bpm:f64[1000] = div bpl 2.0:f64[]
          bpn:f64[1000] = sub bpi bpm
          bpo:f64[1] = mul 6.0:f64[] bfe
          bpp:f64[1] = sub bos bpo
          bpq:f64[1] = add bpp 7.0:f64[]
          bpr:f64[1000] = mul bpq boh
          bps:f64[1000] = add bpn bpr
          bpt:f64[1] = mul 11.0:f64[] bos
          bpu:f64[1] = mul 46.0:f64[] bfe
          bpv:f64[1] = sub bpt bpu
          bpw:f64[1] = add bpv 47.0:f64[]
          bpx:f64[1] = div bpw 6.0:f64[]
          bpy:f64[1000] = add bps bpx
          bpz:f64[1000] = mul bph bpy
          bqa:f64[1] = sub bfe 1.0:f64[]
          bqb:f64[1000] = neg bor
          bqc:f64[1000] = div bqb 4.0:f64[]
          bqd:f64[1] = mul 11.0:f64[] bfe
          bqe:f64[1] = sub bqd 17.0:f64[]
          bqf:f64[1000] = mul bqe boq
          bqg:f64[1000] = div bqf 6.0:f64[]
          bqh:f64[1000] = add bqc bqg
          bqi:f64[1] = mul -3.0:f64[] bos
          bqj:f64[1] = mul 13.0:f64[] bfe
          bqk:f64[1] = add bqi bqj
          bql:f64[1] = sub bqk 13.0:f64[]
          bqm:f64[1000] = mul bql bop
          bqn:f64[1000] = add bqh bqm
          bqo:f64[1] = mul 2.0:f64[] bot
          bqp:f64[1] = mul 25.0:f64[] bos
          bqq:f64[1] = sub bqo bqp
          bqr:f64[1] = mul 72.0:f64[] bfe
          bqs:f64[1] = add bqq bqr
          bqt:f64[1] = sub bqs 61.0:f64[]
          bqu:f64[1000] = mul bqt boh
          bqv:f64[1000] = div bqu 2.0:f64[]
          bqw:f64[1000] = add bqn bqv
          bqx:f64[1] = mul 25.0:f64[] bot
          bqy:f64[1] = mul 195.0:f64[] bos
          bqz:f64[1] = sub bqx bqy
          bra:f64[1] = mul 477.0:f64[] bfe
          brb:f64[1] = add bqz bra
          brc:f64[1] = sub brb 379.0:f64[]
          brd:f64[1] = div brc 12.0:f64[]
          bre:f64[1000] = add bqw brd
          brf:f64[1000] = mul bqa bre
          brg:f64[1000] = add bof boh
          brh:f64[1000] = div brf bof
          bri:f64[1000] = add brh bpz
          brj:f64[1000] = div bri bof
          brk:f64[1000] = div bpg bof
          brl:f64[1000] = add brj brk
          brm:f64[1000] = add brl bow
          brn:f64[1000] = div brm bof
          bro:f64[1000] = add brg brn
          brp:f64[1000] = neg bfk
          brq:f64[1] = sub bfe 1.0:f64[]
          brr:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; brs:f64[1] brt:f64[1000]. let
                bru:bool[1] = ne brs 0.0:f64[]
                brv:f64[1000] = log brt
                brw:f64[1000] = mul brs brv
                brx:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                bry:f64[1000] = pjit[name=_where jaxpr=_where3] bru brw brx
              in (bry,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] brq bnt
          brz:f64[1000] = add brp brr
          bsa:f64[1] = sub 1.0:f64[] bfe
          bsb:f64[1000] = add 1.0:f64[] bnt
          bsc:f64[1000] = div bsa bsb
          bsd:f64[1000] = log1p bsc
          bse:f64[1000] = sub brz bsd
          bsf:f64[1000] = neg bfk
          bsg:f64[1] = sub bfe 1.0:f64[]
          bsh:f64[1000] = custom_jvp_call[
            name=xlogy
            call_jaxpr={ lambda ; bsi:f64[1] bsj:f64[1000]. let
                bsk:bool[1] = ne bsi 0.0:f64[]
                bsl:f64[1000] = log bsj
                bsm:f64[1000] = mul bsi bsl
                bsn:f64[1] = broadcast_in_dim[
                  broadcast_dimensions=()
                  shape=(1,)
                  sharding=None
                ] 0.0:f64[]
                bso:f64[1000] = pjit[name=_where jaxpr=_where3] bsk bsm bsn
              in (bso,) }
            jvp=_xlogy_jvp
            symbolic_zeros=False
          ] bsg bse
          bsp:f64[1000] = add bsf bsh
          bsq:f64[1] = sub 1.0:f64[] bfe
          bsr:f64[1000] = add 1.0:f64[] bse
          bss:f64[1000] = div bsq bsr
          bst:f64[1000] = log1p bss
          bsu:f64[1000] = sub bsp bst
          bsv:f64[1000] = pjit[name=_where jaxpr=_where] boe bro bsu
          bsw:f64[1] = mul 3.0:f64[] bfe
          bsx:bool[1000] = lt bnt bsw
          bsy:f64[1000] = pjit[name=_where jaxpr=_where] bsx bnt bsv
          bsz:bool[1] = ge bfe 500.0:f64[]
          bta:f64[1000] = div bnt bfe
          btb:f64[1000] = sub 1.0:f64[] bta
          btc:f64[1000] = abs btb
          btd:bool[1000] = lt btc 1e-06:f64[]
          bte:bool[1000] = and bsz btd
          btf:f64[1000] = pjit[name=_where jaxpr=_where] bte bnt bsy
          btg:f64[1000] = log bff
          bth:f64[1] = add bfe 1.0:f64[]
          bti:f64[1] = lgamma bth
          btj:f64[1000] = add btg bti
          btk:f64[1000] = add btj bnt
          btl:f64[1000] = div btk bfe
          btm:f64[1000] = exp btl
          btn:f64[1] = add bfe 1.0:f64[]
          bto:f64[1000] = div btm btn
          btp:f64[1] = add bfe 2.0:f64[]
          btq:f64[1000] = div btm btp
          btr:f64[1000] = add 1.0:f64[] btq
          bts:f64[1000] = mul bto btr
          btt:f64[1000] = log1p bts
          btu:f64[1000] = add btj btm
          btv:f64[1000] = sub btu btt
          btw:f64[1000] = div btv bfe
          btx:f64[1000] = exp btw
          bty:f64[1] = add bfe 1.0:f64[]
          btz:f64[1000] = div btx bty
          bua:f64[1] = add bfe 2.0:f64[]
          bub:f64[1000] = div btx bua
          buc:f64[1000] = add 1.0:f64[] bub
          bud:f64[1000] = mul btz buc
          bue:f64[1000] = log1p bud
          buf:f64[1000] = add btj btx
          bug:f64[1000] = sub buf bue
          buh:f64[1000] = div bug bfe
          bui:f64[1000] = exp buh
          buj:f64[1] = add bfe 1.0:f64[]
          buk:f64[1000] = div bui buj
          bul:f64[1] = add bfe 2.0:f64[]
          bum:f64[1000] = div bui bul
          bun:f64[1] = add bfe 3.0:f64[]
          buo:f64[1000] = div bui bun
          bup:f64[1000] = add 1.0:f64[] buo
          buq:f64[1000] = mul bum bup
          bur:f64[1000] = add 1.0:f64[] buq
          bus:f64[1000] = mul buk bur
          but:f64[1000] = log1p bus
          buu:f64[1000] = add btj bui
          buv:f64[1000] = sub buu but
          buw:f64[1000] = div buv bfe
          bux:f64[1000] = exp buw
          buy:f64[1] = add bfe 1.0:f64[]
          buz:f64[1] = mul 0.15:f64[] buy
          bva:bool[1000] = le bnt buz
          bvb:f64[1000] = pjit[name=_where jaxpr=_where] bva bux bnt
          bvc:bool[1000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1000,)
            sharding=None
          ] False:bool[]
          bvd:f64[1000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1000,)
            sharding=None
          ] 1.0:f64[]
          bve:f64[1000] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1000,)
            sharding=None
          ] 1.0:f64[]
          _:bool[1000] _:f64[] _:f64[1000] bvf:f64[1000] = while[
            body_jaxpr={ lambda ; bvg:f64[1000] bvh:f64[1] bvi:bool[1000] bvj:f64[]
                bvk:f64[1000] bvl:f64[1000]. let
                bvm:f64[1000] = mul bvk bvg
                bvn:f64[1] = add bvh bvj
                bvo:f64[1000] = div bvm bvn
                bvp:f64[1000] = add bvl bvo
                bvq:f64[1000] = pjit[name=_where jaxpr=_where4] bvi bvl bvp
                bvr:bool[1000] = lt bvo 0.0001:f64[]
                bvs:bool[] = gt bvj 100.0:f64[]
                bvt:bool[1000] = or bvr bvs
                bvu:f64[] = add bvj 1.0:f64[]
              in (bvt, bvu, bvo, bvq) }
            body_nconsts=2
            cond_jaxpr={ lambda ; bvv:bool[1000] bvw:f64[] bvx:f64[1000] bvy:f64[1000]. let
                bvz:bool[1000] = not bvv
                bwa:bool[] = reduce_or[axes=(0,)] bvz
              in (bwa,) }
            cond_nconsts=0
          ] bvb bfe bvc 1.0:f64[] bvd bve
          bwb:f64[1000] = log bvf
          bwc:f64[1000] = add btj bvb
          bwd:f64[1000] = sub bwc bwb
          bwe:f64[1000] = div bwd bfe
          bwf:f64[1000] = exp bwe
          bwg:f64[1] = add bfe 1.0:f64[]
          bwh:f64[1] = mul 0.01:f64[] bwg
          bwi:bool[1000] = le bvb bwh
          bwj:f64[1] = add bfe 1.0:f64[]
          bwk:f64[1] = mul 0.7:f64[] bwj
          bwl:bool[1000] = gt bvb bwk
          bwm:bool[1000] = or bwi bwl
          bwn:f64[1000] = log bwf
          bwo:f64[1000] = mul bfe bwn
          bwp:f64[1000] = sub bwo bwf
          bwq:f64[1000] = sub bwp btj
          bwr:f64[1000] = add bwq bwb
          bws:f64[1000] = sub bfe bwf
          bwt:f64[1000] = div bwr bws
          bwu:f64[1000] = sub 1.0:f64[] bwt
          bwv:f64[1000] = mul bwf bwu
          bww:f64[1000] = pjit[name=_where jaxpr=_where] bwm bvb bwv
          bwx:bool[1000] = le bff 0.5:f64[]
          bwy:f64[1000] = pjit[name=_where jaxpr=_where] bwx bww btf
          bwz:bool[1] = lt bfe 1.0:f64[]
          bxa:f64[1000] = pjit[name=_where jaxpr=_where1] bwz blz bwy
          bxb:bool[1] = eq bfe 1.0:f64[]
          bxc:f64[1000] = neg bfj
          bxd:f64[1000] = pjit[name=_where jaxpr=_where1] bxb bxc bxa
          bxe:f64[1000] = log bxd
          bxf:f64[1000] = mul bfe bxe
          bxg:f64[1000] = sub bxf bxd
          bxh:f64[1] = lgamma bfe
          bxi:f64[1000] = sub bxg bxh
          bxj:f64[1000] = exp bxi
          bxk:bool[1000] = le bff 0.9:f64[]
          bxl:bool[1000] = and bxk True:bool[]
          bxm:bool[1000] = gt bfg 0.9:f64[]
          bxn:bool[1000] = and bxm False:bool[]
          bxo:bool[1000] = or bxl bxn
          bxp:f64[1000] = igamma bfe bxd
          bxq:f64[1000] = sub bxp bff
          bxr:f64[1000] = mul bxq bxd
          bxs:f64[1000] = div bxr bxj
          bxt:f64[1000] = igammac bfe bxd
          bxu:f64[1000] = sub bxt bfg
          bxv:f64[1000] = neg bxu
          bxw:f64[1000] = mul bxv bxd
          bxx:f64[1000] = div bxw bxj
          bxy:f64[1000] = pjit[name=_where jaxpr=_where] bxo bxs bxx
          bxz:f64[1] = sub bfe 1.0:f64[]
          bya:f64[1000] = div bxz bxd
          byb:f64[1000] = add -1.0:f64[] bya
          byc:bool[1000] = pjit[name=isinf jaxpr=isinf] byb
          byd:f64[1000] = sub bxd bxy
          bye:f64[1000] = mul 0.5:f64[] bxy
          byf:f64[1000] = mul bye byb
          byg:f64[1000] = sub 1.0:f64[] byf
          byh:f64[1000] = div bxy byg
          byi:f64[1000] = sub bxd byh
          byj:f64[1000] = pjit[name=_where jaxpr=_where] byc byd byi
          byk:bool[1000] = eq bxj 0.0:f64[]
          byl:f64[1000] = pjit[name=_where jaxpr=_where] byk bxd byj
          bym:f64[1000] = log byl
          byn:f64[1000] = mul bfe bym
          byo:f64[1000] = sub byn byl
          byp:f64[1] = lgamma bfe
          byq:f64[1000] = sub byo byp
          byr:f64[1000] = exp byq
          bys:bool[1000] = le bff 0.9:f64[]
          byt:bool[1000] = and bys True:bool[]
          byu:bool[1000] = gt bfg 0.9:f64[]
          byv:bool[1000] = and byu False:bool[]
          byw:bool[1000] = or byt byv
          byx:f64[1000] = igamma bfe byl
          byy:f64[1000] = sub byx bff
          byz:f64[1000] = mul byy byl
          bza:f64[1000] = div byz byr
          bzb:f64[1000] = igammac bfe byl
          bzc:f64[1000] = sub bzb bfg
          bzd:f64[1000] = neg bzc
          bze:f64[1000] = mul bzd byl
          bzf:f64[1000] = div bze byr
          bzg:f64[1000] = pjit[name=_where jaxpr=_where] byw bza bzf
          bzh:f64[1] = sub bfe 1.0:f64[]
          bzi:f64[1000] = div bzh byl
          bzj:f64[1000] = add -1.0:f64[] bzi
          bzk:bool[1000] = pjit[name=isinf jaxpr=isinf] bzj
          bzl:f64[1000] = sub byl bzg
          bzm:f64[1000] = mul 0.5:f64[] bzg
          bzn:f64[1000] = mul bzm bzj
          bzo:f64[1000] = sub 1.0:f64[] bzn
          bzp:f64[1000] = div bzg bzo
          bzq:f64[1000] = sub byl bzp
          bzr:f64[1000] = pjit[name=_where jaxpr=_where] bzk bzl bzq
          bzs:bool[1000] = eq byr 0.0:f64[]
          bzt:f64[1000] = pjit[name=_where jaxpr=_where] bzs byl bzr
          bzu:f64[1000] = log bzt
          bzv:f64[1000] = mul bfe bzu
          bzw:f64[1000] = sub bzv bzt
          bzx:f64[1] = lgamma bfe
          bzy:f64[1000] = sub bzw bzx
          bzz:f64[1000] = exp bzy
          caa:bool[1000] = le bff 0.9:f64[]
          cab:bool[1000] = and caa True:bool[]
          cac:bool[1000] = gt bfg 0.9:f64[]
          cad:bool[1000] = and cac False:bool[]
          cae:bool[1000] = or cab cad
          caf:f64[1000] = igamma bfe bzt
          cag:f64[1000] = sub caf bff
          cah:f64[1000] = mul cag bzt
          cai:f64[1000] = div cah bzz
          caj:f64[1000] = igammac bfe bzt
          cak:f64[1000] = sub caj bfg
          cal:f64[1000] = neg cak
          cam:f64[1000] = mul cal bzt
          can:f64[1000] = div cam bzz
          cao:f64[1000] = pjit[name=_where jaxpr=_where] cae cai can
          cap:f64[1] = sub bfe 1.0:f64[]
          caq:f64[1000] = div cap bzt
          car:f64[1000] = add -1.0:f64[] caq
          cas:bool[1000] = pjit[name=isinf jaxpr=isinf] car
          cat:f64[1000] = sub bzt cao
          cau:f64[1000] = mul 0.5:f64[] cao
          cav:f64[1000] = mul cau car
          caw:f64[1000] = sub 1.0:f64[] cav
          cax:f64[1000] = div cao caw
          cay:f64[1000] = sub bzt cax
          caz:f64[1000] = pjit[name=_where jaxpr=_where] cas cat cay
          cba:bool[1000] = eq bzz 0.0:f64[]
          cbb:f64[1000] = pjit[name=_where jaxpr=_where] cba bzt caz
          cbc:bool[1] = lt bfe 0.0:f64[]
          cbd:bool[1000] = lt bff 0.0:f64[]
          cbe:bool[1000] = or cbc cbd
          cbf:bool[1000] = gt bff 1.0:f64[]
          cbg:bool[1000] = or cbe cbf
          cbh:f64[1000] = pjit[name=_where jaxpr=_where2] cbg nan:f64[] cbb
          cbi:bool[1000] = eq bff 0.0:f64[]
          cbj:f64[1000] = pjit[name=_where jaxpr=_where2] cbi 0.0:f64[] cbh
          cbk:bool[1000] = eq bff 1.0:f64[]
          cbl:f64[1000] = pjit[name=_where jaxpr=_where2] cbk inf:f64[] cbj
        in (cbl,) }
      jvp=_igammainv_jvp
      num_consts=2
      symbolic_zeros=False
    ] dg dh bfa bez
    cbm:f64[1000] = mul 2.0:f64[] bfb
    cbn:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] dr
    cbo:f64[1000] = div cbm bev
    cbp:f64[1000] = sqrt cbo
    cbq:f64[1000,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(1000, 1, 1)
      sharding=None
    ] cbp
    cbr:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] dr
    cbs:f64[1000,2,2] = sub ho cbr
    cbt:f64[1000,2,2] = mul cbq cbs
    cbu:f64[1000,2,2] = add cbn cbt
    cbv:f64[4000,2,2] = concatenate[dimension=0] hk ho ber cbu
    cbw:f64[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=None] di
    cbx:f64[2] = squeeze[dimensions=(0,)] cbw
    cby:f64[2] cbz:f64[2,2] = pjit[
      name=eigh
      jaxpr={ lambda ; dj:f64[2,2]. let
          cca:f64[2,2] = transpose[permutation=(1, 0)] dj
          ccb:f64[2,2] = add dj cca
          ccc:f64[2,2] = div ccb 2.0:f64[]
          cbz:f64[2,2] cby:f64[2] = eigh[
            lower=True
            sort_eigenvalues=True
            subset_by_index=None
          ] ccc
        in (cby, cbz) }
    ] dj
    ccd:f64[2] = abs cby
    cce:f64[2] = sqrt ccd
    ccf:f64[2,2] = dot_general[
      dimension_numbers=(([], []), ([1], [0]))
      preferred_element_type=float64
    ] cbz cce
    ccg:f64[2,2] = pjit[
      name=qr
      jaxpr={ lambda ; ccf:f64[2,2]. let
          cch:f64[2,2] ccg:f64[2,2] = qr[
            full_matrices=True
            pivoting=False
            use_magma=None
          ] ccf
        in (ccg,) }
    ] ccf
    cci:u32[2,2] = iota[dimension=0 dtype=uint32 shape=(2, 2) sharding=None] 
    ccj:u32[2,2] = iota[dimension=1 dtype=uint32 shape=(2, 2) sharding=None] 
    cck:bool[2,2] = eq cci ccj
    ccl:f64[2,2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2, 2)
      sharding=None
    ] 0.0:f64[]
    ccm:f64[2,2] = select_n cck ccl ccg
    ccn:f64[2] = reduce_sum[axes=(0,)] ccm
    cco:f64[2,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(2, 1)
      sharding=None
    ] ccn
    ccp:f64[2,1] = sign cco
    ccq:f64[2,2] = mul ccg ccp
    ccr:f64[2,2] = transpose[permutation=(1, 0)] ccq
    ccs:f64[4000,1,2] = slice[
      limit_indices=(4000, 1, 2)
      start_indices=(0, 0, 0)
      strides=None
    ] cbv
    cct:f64[4000,2] = squeeze[dimensions=(1,)] ccs
    ccu:f64[1,2] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 2)
      sharding=None
    ] cbx
    ccv:f64[4000,2] = sub cct ccu
    ccw:f64[4000,2,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(4000, 2, 1)
      sharding=None
    ] ccv
    ccx:f64[2,2] = pjit[name=tril jaxpr=tril] ccr
    ccy:f64[2,1,4000] = pjit[
      name=_solve_triangular
      jaxpr={ lambda ; ccx:f64[2,2] ccw:f64[4000,2,1]. let
          ccz:f64[2,1,4000] = transpose[permutation=(1, 2, 0)] ccw
          cda:f64[2,4000] = reshape[
            dimensions=None
            new_sizes=(2, 4000)
            sharding=None
          ] ccz
          cdb:f64[2,4000] = triangular_solve[
            conjugate_a=False
            left_side=True
            lower=True
            transpose_a=False
            unit_diagonal=False
          ] ccx cda
          ccy:f64[2,1,4000] = reshape[
            dimensions=None
            new_sizes=(2, 1, 4000)
            sharding=None
          ] cdb
        in (ccy,) }
    ] ccx ccw
    cdc:f64[4000,2,1] = transpose[permutation=(2, 0, 1)] ccy
    cdd:f64[4000,2] = squeeze[dimensions=(2,)] cdc
    cde:f64[4000,2] = div cdd 1.0:f64[]
    cdf:f64[] = div 0.0:f64[] 1.0:f64[]
    cdg:f64[4000,2] = sub cde cdf
    cdh:f64[4000,2] = square cdg
    cdi:f64[4000,2] = mul -0.5:f64[] cdh
    cdj:f64[] = log 1.0:f64[]
    cdk:f64[] = add 0.9189385332046727:f64[] cdj
    cdl:f64[4000,2] = sub cdi cdk
    cdm:f64[4000] = reduce_sum[axes=(1,)] cdl
    cdn:f64[2] = pjit[name=diagonal jaxpr=diagonal] ccr
    cdo:f64[2] = abs cdn
    cdp:f64[2] = log cdo
    cdq:f64[] = reduce_sum[axes=(0,)] cdp
    cdr:f64[] = neg cdq
    cds:f64[] = neg cdr
    cdt:f64[] = mul 1.0:f64[] cds
    cdu:f64[] = reduce_sum[axes=()] cdt
    cdv:f64[] = add 0.0:f64[] cdu
    cdw:f64[] = neg 0.0:f64[]
    cdx:f64[] = neg cdw
    cdy:f64[2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2,)
      sharding=None
    ] 1.0:f64[]
    cdz:f64[2] = mul cdy cdx
    cea:f64[] = reduce_sum[axes=(0,)] cdz
    ceb:f64[] = add cdv cea
    cec:f64[] = copy ceb
    ced:f64[4000] = sub cdm cec
    cee:f64[4000,1,2] = slice[
      limit_indices=(4000, 1, 2)
      start_indices=(0, 0, 0)
      strides=None
    ] cbv
    cef:f64[1,1,2] = transpose[permutation=(0, 2, 1)] dk
    ceg:f64[4000,1,2] = slice[
      limit_indices=(4000, 2, 2)
      start_indices=(0, 1, 0)
      strides=None
    ] cbv
    ceh:f64[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=None] di
    cei:f64[1,1,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 1, 2)
      sharding=None
    ] ceh
    cej:f64[4000,1,2] = sub ceg cei
    cek:f64[4000,1,2,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(4000, 1, 2, 1)
      sharding=None
    ] cee
    cel:f64[2,2] = squeeze[dimensions=(0,)] dl
    cem:f64[2,4000,1,1] = dot_general[
      dimension_numbers=(([1], [2]), ([], []))
      preferred_element_type=float64
    ] cel cek
    cen:f64[4000,1,2,1] = transpose[permutation=(1, 2, 0, 3)] cem
    ceo:f64[4000,1,2,1] = slice[
      limit_indices=(4000, 1, 2, 1)
      start_indices=(0, 0, 0, 0)
      strides=None
    ] cen
    cep:f64[4000,1,2] = squeeze[dimensions=(3,)] ceo
    ceq:f64[4000,1,2] = sub cej cep
    cer:f64[1,1,4000] = dot_general[
      dimension_numbers=(([2], [2]), ([0], [1]))
      preferred_element_type=float64
    ] cef ceq
    ces:f64[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 0.0:f64[]
    cet:f64[1,1] ceu:f64[1,1,1] = pjit[
      name=eigh
      jaxpr={ lambda ; dm:f64[1,1,1]. let
          cev:f64[1,1,1] = transpose[permutation=(0, 2, 1)] dm
          cew:f64[1,1,1] = add dm cev
          cex:f64[1,1,1] = div cew 2.0:f64[]
          ceu:f64[1,1,1] cet:f64[1,1] = eigh[
            lower=True
            sort_eigenvalues=True
            subset_by_index=None
          ] cex
        in (cet, ceu) }
    ] dm
    cey:f64[1,1] = abs cet
    cez:f64[1,1] = sqrt cey
    cfa:f64[1,1,1] = dot_general[
      dimension_numbers=(([], []), ([0, 2], [0, 1]))
      preferred_element_type=float64
    ] ceu cez
    cfb:f64[1,1,1] = pjit[
      name=qr
      jaxpr={ lambda ; cfa:f64[1,1,1]. let
          cfc:f64[1,1,1] cfb:f64[1,1,1] = qr[
            full_matrices=True
            pivoting=False
            use_magma=None
          ] cfa
        in (cfb,) }
    ] cfa
    cfd:u32[1,1] = iota[dimension=0 dtype=uint32 shape=(1, 1) sharding=None] 
    cfe:u32[1,1] = iota[dimension=1 dtype=uint32 shape=(1, 1) sharding=None] 
    cff:bool[1,1] = eq cfd cfe
    cfg:bool[1,1,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 1, 1)
      sharding=None
    ] cff
    cfh:f64[1,1,1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1, 1, 1)
      sharding=None
    ] 0.0:f64[]
    cfi:f64[1,1,1] = select_n cfg cfh cfb
    cfj:f64[1,1] = reduce_sum[axes=(1,)] cfi
    cfk:f64[1,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(1, 1, 1)
      sharding=None
    ] cfj
    cfl:f64[1,1,1] = sign cfk
    cfm:f64[1,1,1] = mul cfb cfl
    cfn:f64[1,1,1] = transpose[permutation=(0, 2, 1)] cfm
    cfo:f64[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 0.0:f64[]
    cfp:f64[1,1] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 1)
      sharding=None
    ] ces
    cfq:f64[4000,1,1] = transpose[permutation=(2, 0, 1)] cer
    cfr:f64[1,1,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 1, 1)
      sharding=None
    ] cfp
    cfs:f64[4000,1,1] = sub cfq cfr
    cft:f64[4000,1,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(4000, 1, 1, 1)
      sharding=None
    ] cfs
    cfu:f64[1,1,1] = pjit[
      name=tril
      jaxpr={ lambda ; cfn:f64[1,1,1]. let
          cfv:i32[1,1] = iota[
            dimension=0
            dtype=int32
            shape=(1, 1)
            sharding=None
          ] 
          cfw:i32[1,1] = add cfv 0:i32[]
          cfx:i32[1,1] = iota[
            dimension=1
            dtype=int32
            shape=(1, 1)
            sharding=None
          ] 
          cfy:bool[1,1] = ge cfw cfx
          cfz:bool[1,1,1] = broadcast_in_dim[
            broadcast_dimensions=(1, 2)
            shape=(1, 1, 1)
            sharding=None
          ] cfy
          cga:f64[1,1,1] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1, 1, 1)
            sharding=None
          ] 0.0:f64[]
          cfu:f64[1,1,1] = select_n cfz cga cfn
        in (cfu,) }
    ] cfn
    cgb:f64[1,1,1,4000] = pjit[
      name=_solve_triangular
      jaxpr={ lambda ; cfu:f64[1,1,1] cft:f64[4000,1,1,1]. let
          cgc:f64[1,1,1,4000] = transpose[permutation=(1, 2, 3, 0)] cft
          cgd:f64[1,1,4000] = reshape[
            dimensions=None
            new_sizes=(1, 1, 4000)
            sharding=None
          ] cgc
          cge:f64[1,1,4000] = triangular_solve[
            conjugate_a=False
            left_side=True
            lower=True
            transpose_a=False
            unit_diagonal=False
          ] cfu cgd
          cgb:f64[1,1,1,4000] = reshape[
            dimensions=None
            new_sizes=(1, 1, 1, 4000)
            sharding=None
          ] cge
        in (cgb,) }
    ] cfu cft
    cgf:f64[4000,1,1,1] = transpose[permutation=(3, 0, 1, 2)] cgb
    cgg:f64[4000,1,1] = squeeze[dimensions=(3,)] cgf
    cgh:f64[4000,1,1] = transpose[permutation=(0, 2, 1)] cgg
    cgi:f64[4000,1,1] = div cgh 1.0:f64[]
    cgj:f64[1] = div cfo 1.0:f64[]
    cgk:f64[1,1] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 1)
      sharding=None
    ] cgj
    cgl:f64[1,1,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 1, 1)
      sharding=None
    ] cgk
    cgm:f64[4000,1,1] = sub cgi cgl
    cgn:f64[4000,1,1] = square cgm
    cgo:f64[4000,1,1] = mul -0.5:f64[] cgn
    cgp:f64[] = log 1.0:f64[]
    cgq:f64[] = add 0.9189385332046727:f64[] cgp
    cgr:f64[4000,1,1] = sub cgo cgq
    cgs:f64[4000,1] = reduce_sum[axes=(1,)] cgr
    cgt:f64[1,1] = pjit[
      name=diagonal
      jaxpr={ lambda ; cfn:f64[1,1,1]. let
          cgu:i64[1] = iota[dimension=0 dtype=int64 shape=(1,) sharding=None] 
          cgv:i64[1] = iota[dimension=0 dtype=int64 shape=(1,) sharding=None] 
          cgw:bool[1] = lt cgu 0:i64[]
          cgx:i64[1] = add cgu 1:i64[]
          cgy:i64[1] = select_n cgw cgu cgx
          cgz:bool[1] = lt cgv 0:i64[]
          cha:i64[1] = add cgv 1:i64[]
          chb:i64[1] = select_n cgz cgv cha
          chc:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] cgy
          chd:i32[1] = convert_element_type[new_dtype=int32 weak_type=False] chb
          che:i32[1,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(1, 1)
            sharding=None
          ] chc
          chf:i32[1,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(1, 1)
            sharding=None
          ] chd
          chg:i32[1,2] = concatenate[dimension=1] che chf
          cgt:f64[1,1] = gather[
            dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1, 2), start_index_map=(1, 2), operand_batching_dims=(), start_indices_batching_dims=())
            fill_value=None
            indices_are_sorted=False
            mode=GatherScatterMode.PROMISE_IN_BOUNDS
            slice_sizes=(1, 1, 1)
            unique_indices=False
          ] cfn chg
        in (cgt,) }
    ] cfn
    chh:f64[1,1] = abs cgt
    chi:f64[1,1] = log chh
    chj:f64[1] = reduce_sum[axes=(1,)] chi
    chk:f64[1] = neg chj
    chl:f64[1] = neg chk
    chm:f64[1] = mul 1.0:f64[] chl
    chn:f64[1] = reduce_sum[axes=()] chm
    cho:f64[1] = add 0.0:f64[] chn
    chp:f64[] = neg 0.0:f64[]
    chq:f64[] = neg chp
    chr:f64[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 1.0:f64[]
    chs:f64[1] = mul chr chq
    cht:f64[] = reduce_sum[axes=(0,)] chs
    chu:f64[1] = add cho cht
    chv:f64[1] = copy chu
    chw:f64[1,1] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 1)
      sharding=None
    ] chv
    chx:f64[4000,1] = sub cgs chw
    chy:f64[4000,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(4000, 1)
      sharding=None
    ] ced
    chz:f64[4000,2] = concatenate[dimension=1] chy chx
    cia:f64[4000] = reduce_sum[axes=(1,)] chz
    cib:f64[4000,2,2,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(4000, 2, 2, 1)
      sharding=None
    ] cbv
    cic:f64[2,1,4000,1] = dot_general[
      dimension_numbers=(([2], [2]), ([0], [1]))
      preferred_element_type=float64
    ] dn cib
    cid:f64[2,1,4000,1] = slice[
      limit_indices=(2, 1, 4000, 1)
      start_indices=(0, 0, 0, 0)
      strides=None
    ] cic
    cie:f64[4000,2,1,1] = transpose[permutation=(2, 0, 1, 3)] cid
    cif:f64[4000,2,1] = squeeze[dimensions=(3,)] cie
    cig:f64[1,2,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 1)
      sharding=None
    ] do
    cih:f64[4000,2,1] = add cig cif
    cii:f64[4000,2,1] = copy cih
    cij:f64[2,1] = copy dp
    cik:f64[2,1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2, 1)
      sharding=None
    ] 0.0:f64[]
    cil:f64[2,1] = custom_jvp_call[
      name=_maximum_
      call_jaxpr={ lambda ; cim:f64[2,1] cin:f64[2,1]. let
          cio:f64[2,1] = max cim cin
        in (cio,) }
      jvp=_maximum_jvp
      symbolic_zeros=False
    ] cij cik
    cip:bool[2,1] = eq cil 0.0:f64[]
    ciq:f64[1,2,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 1)
      sharding=None
    ] cil
    cir:f64[4000,2,1] = mul cii ciq
    cis:f64[4000,2,1] = pjit[
      name=_where
      jaxpr={ lambda ; cip:bool[2,1] cit:f64[] cir:f64[4000,2,1]. let
          ciu:f64[2,1] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(2, 1)
            sharding=None
          ] cit
          civ:bool[4000,2,1] = broadcast_in_dim[
            broadcast_dimensions=(1, 2)
            shape=(4000, 2, 1)
            sharding=None
          ] cip
          ciw:f64[4000,2,1] = broadcast_in_dim[
            broadcast_dimensions=(1, 2)
            shape=(4000, 2, 1)
            sharding=None
          ] ciu
          cis:f64[4000,2,1] = select_n civ cir ciw
        in (cis,) }
    ] cip 0.0:f64[] cir
    cix:f64[2,1] = add 1.0:f64[] cil
    ciy:f64[2,1] = lgamma cix
    ciz:f64[1,2,1] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 1)
      sharding=None
    ] ciy
    cja:f64[4000,2,1] = sub cis ciz
    cjb:bool[2,1] = eq dp cil
    cjc:f64[4000,2,1] = pjit[
      name=_where
      jaxpr={ lambda ; cjb:bool[2,1] cja:f64[4000,2,1] cjd:f64[]. let
          cje:f64[2,1] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(2, 1)
            sharding=None
          ] cjd
          cjf:bool[4000,2,1] = broadcast_in_dim[
            broadcast_dimensions=(1, 2)
            shape=(4000, 2, 1)
            sharding=None
          ] cjb
          cjg:f64[4000,2,1] = broadcast_in_dim[
            broadcast_dimensions=(1, 2)
            shape=(4000, 2, 1)
            sharding=None
          ] cje
          cjc:f64[4000,2,1] = select_n cjf cjg cja
        in (cjc,) }
    ] cjb cja -inf:f64[]
    cjh:f64[4000,2,1] = exp cii
    cji:f64[4000,2,1] = sub cjc cjh
    cjj:f64[4000,2] = reduce_sum[axes=(2,)] cji
    cjk:f64[4000] = reduce_sum[axes=(1,)] cjj
    cjl:f64[4000] = add cia cjk
    cjm:f64[1,2,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 2, 2)
      sharding=None
    ] dr
    cjn:f64[4000,2,2] = sub cbv cjm
    cjo:f64[4000,1,2] = slice[
      limit_indices=(4000, 2, 2)
      start_indices=(0, 1, 0)
      strides=None
    ] cjn
    cjp:f64[4000,1,2] = slice[
      limit_indices=(4000, 1, 2)
      start_indices=(0, 0, 0)
      strides=None
    ] cjn
    cjq:f64[4000,1,2] = pjit[
      name=_solve_triangular
      jaxpr={ lambda ; dt:f64[1,2,2] cjp:f64[4000,1,2]. let
          cjr:f64[4000,1,2,1] = broadcast_in_dim[
            broadcast_dimensions=(0, 1, 2)
            shape=(4000, 1, 2, 1)
            sharding=None
          ] cjp
          cjs:f64[1,2,1,4000] = transpose[permutation=(1, 2, 3, 0)] cjr
          cjt:f64[1,2,4000] = reshape[
            dimensions=None
            new_sizes=(1, 2, 4000)
            sharding=None
          ] cjs
          cju:f64[1,2,4000] = triangular_solve[
            conjugate_a=False
            left_side=True
            lower=True
            transpose_a=False
            unit_diagonal=False
          ] dt cjt
          cjv:f64[1,2,1,4000] = reshape[
            dimensions=None
            new_sizes=(1, 2, 1, 4000)
            sharding=None
          ] cju
          cjw:f64[1,2,1,4000] = slice[
            limit_indices=(1, 2, 1, 4000)
            start_indices=(0, 0, 0, 0)
            strides=None
          ] cjv
          cjx:f64[4000,1,2,1] = transpose[permutation=(3, 0, 1, 2)] cjw
          cjq:f64[4000,1,2] = squeeze[dimensions=(3,)] cjx
        in (cjq,) }
    ] dt cjp
    cjy:f64[1,2,4000] = dot_general[
      dimension_numbers=(([2], [2]), ([0], [1]))
      preferred_element_type=float64
    ] du cjq
    cjz:f64[4000,1,2] = transpose[permutation=(2, 0, 1)] cjy
    cka:f64[4000,1,2] = sub cjo cjz
    ckb:f64[4000,1,2] = slice[
      limit_indices=(4000, 1, 2)
      start_indices=(0, 0, 0)
      strides=None
    ] cjn
    ckc:f64[4000,2,2] = concatenate[dimension=1] ckb cka
    ckd:f64[4000,2,2] = pjit[
      name=_solve_triangular
      jaxpr={ lambda ; ds:f64[2,2,2] ckc:f64[4000,2,2]. let
          cke:f64[4000,2,2,1] = broadcast_in_dim[
            broadcast_dimensions=(0, 1, 2)
            shape=(4000, 2, 2, 1)
            sharding=None
          ] ckc
          ckf:f64[2,2,1,4000] = transpose[permutation=(1, 2, 3, 0)] cke
          ckg:f64[2,2,4000] = reshape[
            dimensions=None
            new_sizes=(2, 2, 4000)
            sharding=None
          ] ckf
          ckh:f64[2,2,4000] = triangular_solve[
            conjugate_a=False
            left_side=True
            lower=True
            transpose_a=False
            unit_diagonal=False
          ] ds ckg
          cki:f64[2,2,1,4000] = reshape[
            dimensions=None
            new_sizes=(2, 2, 1, 4000)
            sharding=None
          ] ckh
          ckj:f64[2,2,1,4000] = slice[
            limit_indices=(2, 2, 1, 4000)
            start_indices=(0, 0, 0, 0)
            strides=None
          ] cki
          ckk:f64[4000,2,2,1] = transpose[permutation=(3, 0, 1, 2)] ckj
          ckd:f64[4000,2,2] = squeeze[dimensions=(3,)] ckk
        in (ckd,) }
    ] ds ckc
    ckl:f64[2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2,)
      sharding=None
    ] 0.0:f64[]
    ckm:i64[2,2] = iota[dimension=0 dtype=int64 shape=(2, 2) sharding=None] 
    ckn:i64[2,2] = iota[dimension=1 dtype=int64 shape=(2, 2) sharding=None] 
    cko:i64[2,2] = add ckm 0:i64[]
    ckp:bool[2,2] = eq cko ckn
    ckq:f64[2,2] = convert_element_type[new_dtype=float64 weak_type=False] ckp
    ckr:f64[2,2] = pjit[
      name=cholesky
      jaxpr={ lambda ; ckq:f64[2,2]. let
          cks:f64[2,2] = transpose[permutation=(1, 0)] ckq
          ckt:f64[2,2] = add ckq cks
          cku:f64[2,2] = div ckt 2.0:f64[]
          ckv:f64[2,2] = cholesky cku
          ckw:i32[2,2] = iota[
            dimension=0
            dtype=int32
            shape=(2, 2)
            sharding=None
          ] 
          ckx:i32[2,2] = add ckw 0:i32[]
          cky:i32[2,2] = iota[
            dimension=1
            dtype=int32
            shape=(2, 2)
            sharding=None
          ] 
          ckz:bool[2,2] = ge ckx cky
          cla:f64[2,2] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(2, 2)
            sharding=None
          ] 0.0:f64[]
          ckr:f64[2,2] = select_n ckz cla ckv
        in (ckr,) }
    ] ckq
    clb:f64[1,2] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 2)
      sharding=None
    ] ckl
    clc:f64[1,1,2] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 1, 2)
      sharding=None
    ] clb
    cld:f64[4000,2,2] = sub ckd clc
    cle:f64[4000,2,2,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2)
      shape=(4000, 2, 2, 1)
      sharding=None
    ] cld
    clf:f64[2,2] = pjit[name=tril jaxpr=tril] ckr
    clg:f64[4000,2,1,2] = transpose[permutation=(0, 2, 3, 1)] cle
    clh:f64[4000,2,2] = reshape[
      dimensions=None
      new_sizes=(4000, 2, 2)
      sharding=None
    ] clg
    cli:f64[2,2,4000] = pjit[
      name=_solve_triangular
      jaxpr={ lambda ; clf:f64[2,2] clh:f64[4000,2,2]. let
          clj:f64[2,2,4000] = transpose[permutation=(1, 2, 0)] clh
          clk:f64[2,8000] = reshape[
            dimensions=None
            new_sizes=(2, 8000)
            sharding=None
          ] clj
          cll:f64[2,8000] = triangular_solve[
            conjugate_a=False
            left_side=True
            lower=True
            transpose_a=False
            unit_diagonal=False
          ] clf clk
          cli:f64[2,2,4000] = reshape[
            dimensions=None
            new_sizes=(2, 2, 4000)
            sharding=None
          ] cll
        in (cli,) }
    ] clf clh
    clm:f64[4000,2,2] = transpose[permutation=(2, 0, 1)] cli
    cln:f64[4000,2,1,2] = reshape[
      dimensions=None
      new_sizes=(4000, 2, 1, 2)
      sharding=None
    ] clm
    clo:f64[4000,2,2,1] = transpose[permutation=(0, 3, 1, 2)] cln
    clp:f64[4000,2,2] = squeeze[dimensions=(3,)] clo
    clq:f64[4000,2,2] = div clp 1.0:f64[]
    clr:f64[] = div 0.0:f64[] 1.0:f64[]
    cls:f64[4000,2,2] = sub clq clr
    clt:f64[4000,2,2] = square cls
    clu:f64[4000,2,2] = mul -0.5:f64[] clt
    clv:f64[] = log 1.0:f64[]
    clw:f64[] = add 0.9189385332046727:f64[] clv
    clx:f64[4000,2,2] = sub clu clw
    cly:f64[4000,2] = reduce_sum[axes=(2,)] clx
    clz:f64[2] = pjit[name=diagonal jaxpr=diagonal] ckr
    cma:f64[2] = abs clz
    cmb:f64[2] = log cma
    cmc:f64[] = reduce_sum[axes=(0,)] cmb
    cmd:f64[] = neg cmc
    cme:f64[] = neg cmd
    cmf:f64[] = mul 1.0:f64[] cme
    cmg:f64[] = reduce_sum[axes=()] cmf
    cmh:f64[] = add 0.0:f64[] cmg
    cmi:f64[] = neg 0.0:f64[]
    cmj:f64[] = neg cmi
    cmk:f64[2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2,)
      sharding=None
    ] 1.0:f64[]
    cml:f64[2] = mul cmk cmj
    cmm:f64[] = reduce_sum[axes=(0,)] cml
    cmn:f64[] = add cmh cmm
    cmo:f64[] = copy cmn
    cmp:f64[4000,2] = sub cly cmo
    cmq:f64[4000] = reduce_sum[axes=(1,)] cmp
    cmr:f64[2,2] = pjit[
      name=_diag
      jaxpr={ lambda ; ds:f64[2,2,2]. let
          cmr:f64[2,2] = pjit[
            name=diagonal
            jaxpr={ lambda ; ds:f64[2,2,2]. let
                cms:i64[2,2] = iota[
                  dimension=0
                  dtype=int64
                  shape=(2, 2)
                  sharding=None
                ] 
                cmt:i64[2,2] = iota[
                  dimension=1
                  dtype=int64
                  shape=(2, 2)
                  sharding=None
                ] 
                cmu:i64[2,2] = add cms 0:i64[]
                cmv:bool[2,2] = eq cmu cmt
                cmw:f64[2,2] = convert_element_type[
                  new_dtype=float64
                  weak_type=False
                ] cmv
                cmx:i32[] = platform_index[platforms=(('mosaic',), None)] 
                cmr:f64[2,2] = cond[
                  branches=(
                    { lambda ; cmy:f64[2,2] cmz:f64[2,2,2]. let
                        cna:f64[1,2,2] = broadcast_in_dim[
                          broadcast_dimensions=(1, 2)
                          shape=(1, 2, 2)
                          sharding=None
                        ] cmy
                        cnb:f64[2,2,2] = mul cna cmz
                        cnc:f64[2,2] = reduce_sum[axes=(1,)] cnb
                      in (cnc,) }
                    { lambda ; cnd:f64[2,2] cne:f64[2,2,2]. let
                        cnf:i64[2] = iota[
                          dimension=0
                          dtype=int64
                          shape=(2,)
                          sharding=None
                        ] 
                        cng:i64[2] = iota[
                          dimension=0
                          dtype=int64
                          shape=(2,)
                          sharding=None
                        ] 
                        cnh:bool[2] = lt cnf 0:i64[]
                        cni:i64[2] = add cnf 2:i64[]
                        cnj:i64[2] = select_n cnh cnf cni
                        cnk:bool[2] = lt cng 0:i64[]
                        cnl:i64[2] = add cng 2:i64[]
                        cnm:i64[2] = select_n cnk cng cnl
                        cnn:i32[2] = convert_element_type[
                          new_dtype=int32
                          weak_type=False
                        ] cnj
                        cno:i32[2] = convert_element_type[
                          new_dtype=int32
                          weak_type=False
                        ] cnm
                        cnp:i32[2,1] = broadcast_in_dim[
                          broadcast_dimensions=(0,)
                          shape=(2, 1)
                          sharding=None
                        ] cnn
                        cnq:i32[2,1] = broadcast_in_dim[
                          broadcast_dimensions=(0,)
                          shape=(2, 1)
                          sharding=None
                        ] cno
                        cnr:i32[2,2] = concatenate[dimension=1] cnp cnq
                        cns:f64[2,2] = gather[
                          dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(1, 2), start_index_map=(1, 2), operand_batching_dims=(), start_indices_batching_dims=())
                          fill_value=None
                          indices_are_sorted=False
                          mode=GatherScatterMode.PROMISE_IN_BOUNDS
                          slice_sizes=(2, 1, 1)
                          unique_indices=False
                        ] cne cnr
                      in (cns,) }
                  )
                  branches_platforms=(('mosaic',), None)
                ] cmx cmw ds
              in (cmr,) }
          ] ds
        in (cmr,) }
    ] ds
    cnt:f64[2,2] = log cmr
    cnu:f64[] = reduce_sum[axes=(0, 1)] cnt
    cnv:f64[4000] = sub cmq cnu
    cnw:f64[4000] = sub cjl cnv
    cnx:f64[] = reduce_max[axes=(0,)] cnw
    cny:f64[4000] = sub cnw cnx
    cnz:f64[4000] = exp cny
    coa:f64[] = reduce_sum[axes=(0,)] cnz
    cob:f64[4000] = div cnz coa
    coc:f64[4000,1,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(4000, 1, 1)
      sharding=None
    ] cob
    cod:f64[4000,2,2] = mul coc cbv
    coe:f64[2,2] = reduce_sum[axes=(0,)] cod
    cof:f64[4000,1,2] = slice[
      limit_indices=(4000, 1, 2)
      start_indices=(0, 0, 0)
      strides=None
    ] cbv
    cog:f64[4000,1,2] = slice[
      limit_indices=(4000, 2, 2)
      start_indices=(0, 1, 0)
      strides=None
    ] cbv
    coh:f64[4000,1,4] = concatenate[dimension=2] cof cog
    coi:f64[1,4,4] = pjit[
      name=cov
      jaxpr={ lambda ; coh:f64[4000,1,4] cob:f64[4000]. let
          coj:f64[4000,1,4] = pjit[
            name=atleast_2d
            jaxpr={ lambda ; coh:f64[4000,1,4]. let  in (coh,) }
          ] coh
          cok:f64[1,4,4000] = transpose[permutation=(1, 2, 0)] coj
          col:f64[4000] = abs cob
          com:f64[1,4000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 4000)
            sharding=None
          ] col
          con:f64[1] = reduce_sum[axes=(1,)] com
          coo:f64[1,1,4000] = broadcast_in_dim[
            broadcast_dimensions=(1, 2)
            shape=(1, 1, 4000)
            sharding=None
          ] com
          cop:f64[1,4,4000] = mul cok coo
          coq:f64[1,4] = reduce_sum[axes=(2,)] cop
          cor:f64[1,1] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 1)
            sharding=None
          ] con
          cos:f64[1,4] = div coq cor
          cot:f64[4] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(4,)
            sharding=None
          ] con
          cou:f64[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] cot
          cov:f64[] = squeeze[dimensions=(0,)] cou
          cow:f64[4000] = mul col col
          cox:f64[] = reduce_sum[axes=(0,)] cow
          coy:f64[] = mul 1.0:f64[] cox
          coz:f64[] = div coy cov
          cpa:f64[] = sub cov coz
          cpb:f64[1,4,1] = broadcast_in_dim[
            broadcast_dimensions=(0, 1)
            shape=(1, 4, 1)
            sharding=None
          ] cos
          cpc:f64[1,4,4000] = sub cok cpb
          cpd:f64[1,4000] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 4000)
            sharding=None
          ] col
          cpe:f64[1,1,4000] = broadcast_in_dim[
            broadcast_dimensions=(1, 2)
            shape=(1, 1, 4000)
            sharding=None
          ] cpd
          cpf:f64[1,4,4000] = mul cpc cpe
          cpg:f64[1,4000,4] = transpose[permutation=(0, 2, 1)] cpf
          cph:f64[1,4,4] = dot_general[
            dimension_numbers=(([2], [1]), ([0], [0]))
            preferred_element_type=float64
          ] cpc cpg
          coi:f64[1,4,4] = div cph cpa
        in (coi,) }
    ] coh cob
    cpi:f64[2,2] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2, 2)
      sharding=None
    ] 0.0:f64[]
    cpj:f64[1,4,4] = slice[
      limit_indices=(1, 4, 4)
      start_indices=(0, 0, 0)
      strides=None
    ] coi
    cpk:f64[4,4] = squeeze[dimensions=(0,)] cpj
    cpl:f64[2,2] = slice[limit_indices=(2, 2) start_indices=(0, 0) strides=None] cpk
    cpm:f64[4,4] = pjit[
      name=block_diag
      jaxpr={ lambda ; cpi:f64[2,2] cpl:f64[2,2]. let
          cpn:f64[2,2] = pjit[name=atleast_2d jaxpr=atleast_2d] cpi
          cpo:f64[2,2] = pjit[name=atleast_2d jaxpr=atleast_2d] cpl
          cpp:f64[2,4] = pad[padding_config=((0, 0, 0), (2, 0, 0))] cpo 0.0:f64[]
          cpq:f64[2,4] = pad[padding_config=((0, 0, 0), (0, 2, 0))] cpn 0.0:f64[]
          cpm:f64[4,4] = concatenate[dimension=0] cpq cpp
        in (cpm,) }
    ] cpi cpl
    cpr:f64[1,4,4] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 4, 4)
      sharding=None
    ] cpm
    cps:f64[2,4,4] = concatenate[dimension=0] cpr coi
    cpt:f64[2,4] cpu:f64[2,4,4] = pjit[
      name=eigh
      jaxpr={ lambda ; cps:f64[2,4,4]. let
          cpv:f64[2,4,4] = transpose[permutation=(0, 2, 1)] cps
          cpw:f64[2,4,4] = add cps cpv
          cpx:f64[2,4,4] = div cpw 2.0:f64[]
          cpu:f64[2,4,4] cpt:f64[2,4] = eigh[
            lower=True
            sort_eigenvalues=True
            subset_by_index=None
          ] cpx
        in (cpt, cpu) }
    ] cps
    cpy:f64[2,4] = abs cpt
    cpz:f64[2,4] = sqrt cpy
    cqa:f64[2,4,4] = dot_general[
      dimension_numbers=(([], []), ([0, 2], [0, 1]))
      preferred_element_type=float64
    ] cpu cpz
    cqb:f64[2,4,4] = pjit[
      name=qr
      jaxpr={ lambda ; cqa:f64[2,4,4]. let
          cqc:f64[2,4,4] cqb:f64[2,4,4] = qr[
            full_matrices=True
            pivoting=False
            use_magma=None
          ] cqa
        in (cqb,) }
    ] cqa
    cqd:u32[4,4] = iota[dimension=0 dtype=uint32 shape=(4, 4) sharding=None] 
    cqe:u32[4,4] = iota[dimension=1 dtype=uint32 shape=(4, 4) sharding=None] 
    cqf:bool[4,4] = eq cqd cqe
    cqg:bool[2,4,4] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(2, 4, 4)
      sharding=None
    ] cqf
    cqh:f64[2,4,4] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2, 4, 4)
      sharding=None
    ] 0.0:f64[]
    cqi:f64[2,4,4] = select_n cqg cqh cqb
    cqj:f64[2,4] = reduce_sum[axes=(1,)] cqi
    cqk:f64[2,4,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(2, 4, 1)
      sharding=None
    ] cqj
    cql:f64[2,4,1] = sign cqk
    cqm:f64[2,4,4] = mul cqb cql
    cqn:f64[2,4,4] = transpose[permutation=(0, 2, 1)] cqm
    cqo:f64[1,2,2] = slice[
      limit_indices=(2, 2, 2)
      start_indices=(1, 0, 0)
      strides=None
    ] cqn
    cqp:f64[1,2,2] = slice[
      limit_indices=(2, 4, 2)
      start_indices=(1, 2, 0)
      strides=None
    ] cqn
    cqq:f64[2,2,2] = slice[
      limit_indices=(2, 4, 4)
      start_indices=(0, 2, 2)
      strides=None
    ] cqn
  in (dw, coe, cqq, cqo, cqp, cnw) }, ())

During handling of the above exception, another exception occurred:

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/runpy.py:196, in _run_module_as_main()
    195     sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
    197                  "__main__", mod_spec)

File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/runpy.py:86, in _run_code()
     79 run_globals.update(__name__ = mod_name,
     80                    __file__ = fname,
     81                    __cached__ = cached,
   (...)
     84                    __package__ = pkg_name,
     85                    __spec__ = mod_spec)
---> 86 exec(code, run_globals)
     87 return run_globals

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel_launcher.py:18
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/traitlets/config/application.py:1075, in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:739, in start()
    738 try:
--> 739     self.io_loop.start()
    740 except KeyboardInterrupt:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tornado/platform/asyncio.py:211, in start()
    210 def start(self) -> None:
--> 211     self.asyncio_loop.run_forever()

File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/asyncio/base_events.py:603, in run_forever()
    602 while True:
--> 603     self._run_once()
    604     if self._stopping:

File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/asyncio/base_events.py:1909, in _run_once()
   1908     else:
-> 1909         handle._run()
   1910 handle = None

File ~/.local/share/uv/python/cpython-3.10.17-macos-aarch64-none/lib/python3.10/asyncio/events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue()
    544 try:
--> 545     await self.process_one()
    546 except Exception:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:534, in process_one()
    533         return
--> 534 await dispatch(*args)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell()
    436     if inspect.isawaitable(result):
--> 437         await result
    438 except Exception:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py:362, in execute_request()
    361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py:778, in execute_request()
    777 if inspect.isawaitable(reply_content):
--> 778     reply_content = await reply_content
    780 # Flush output before sending the reply.

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py:449, in do_execute()
    448 if accepts_params["cell_id"]:
--> 449     res = shell.run_cell(
    450         code,
    451         store_history=store_history,
    452         silent=silent,
    453         cell_id=cell_id,
    454     )
    455 else:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3077, in run_cell()
   3076 try:
-> 3077     result = self._run_cell(
   3078         raw_cell, store_history, silent, shell_futures, cell_id
   3079     )
   3080 finally:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3132, in _run_cell()
   3131 try:
-> 3132     result = runner(coro)
   3133 except BaseException as e:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner()
    127 try:
--> 128     coro.send(None)
    129 except StopIteration as exc:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3336, in run_cell_async()
   3333 interactivity = "none" if silent else self.ast_node_interactivity
-> 3336 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3337        interactivity=interactivity, compiler=compiler, result=result)
   3339 self.last_execution_succeeded = not has_raised

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3519, in run_ast_nodes()
   3518     asy = compare(code)
-> 3519 if await self.run_code(code, result, async_=asy):
   3520     return True

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3579, in run_code()
   3578     else:
-> 3579         exec(code_obj, self.user_global_ns, self.user_ns)
   3580 finally:
   3581     # Reset our crash handler in place

Cell In[41], line 2
      1 results_are = pd.DataFrame(
----> 2     [
      3         asymptotic_variance(n, k)
      4         for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are))
      5     ]
      6 )
      8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False)

Cell In[41], line 3, in <listcomp>()
      1 results_are = pd.DataFrame(
      2     [
----> 3         asymptotic_variance(n, k)
      4         for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are))
      5     ]
      6 )
      8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False)

Cell In[27], line 37, in asymptotic_variance()
     35 sks_cem = sks[M:]
---> 37 logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M=len(sks_cem))
     38 logdet_meis = asymptotic_det_meis(
     39     Y, pgssm, prop_la, N_iter, N_samples, key, M=len(sks_meis)
     40 )

Cell In[27], line 16, in asymptotic_det_cem()
     15 key, *subkeys = jrn.split(key, 1 + M)
---> 16 proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys]
     17 modes = jnp.array([proposal.mean[:, 0] for proposal in proposals])

Cell In[27], line 16, in <listcomp>()
     15 key, *subkeys = jrn.split(key, 1 + M)
---> 16 proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys]
     17 modes = jnp.array([proposal.mean[:, 0] for proposal in proposals])

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:222, in cross_entropy_method()
    220     return new_proposal, log_w
--> 222 final_proposal, log_w = fori_loop(
    223     0, n_iter, _iteration, (initial, jnp.empty(4 * N))
    224 )
    226 return final_proposal, log_w

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:202, in _iteration()
    200 model_log_weights = partial(log_weight_cem, y=y, model=model, proposal=proposal)
--> 202 samples = simulate_cem(proposal, N, subkey_crn)
    204 _N, np1, m = samples.shape

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:103, in simulate_cem()
    102 l_samples = location_antithetic(samples, mean)
--> 103 s_samples = scale_antithethic(u, samples, mean)
    104 ls_samples = scale_antithethic(u, l_samples, mean)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/util.py:87, in scale_antithethic()
     86 c = jnp.linalg.norm(u, axis=1) ** 2
---> 87 c_prime = chi_dist.quantile(1.0 - chi_dist.cdf(c))
     89 return mean[None] + jnp.sqrt(c_prime / c)[:, None, None] * (samples - mean[None])

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1573, in quantile()
   1556 """Quantile function. Aka 'inverse cdf' or 'percent point function'.
   1557 
   1558 Given random variable `X` and `p in [0, 1]`, the `quantile` is:
   (...)
   1571     values of type `self.dtype`.
   1572 """
-> 1573 return self._call_quantile(value, name, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:1553, in _call_quantile()
   1547   value = distribution_util.with_dependencies([
   1548       assert_util.assert_less_equal(value, tf.cast(1, value.dtype),
   1549                                     message='`value` must be <= 1'),
   1550       assert_util.assert_greater_equal(value, tf.cast(0, value.dtype),
   1551                                        message='`value` must be >= 0')
   1552   ], value)
-> 1553 return self._quantile(value, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/chi2.py:139, in _quantile()
    138 def _quantile(self, p):
--> 139   return 2. * special.igammainv(0.5 * self.df, p)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/math/special.py:1573, in igammainv()
   1572 p = tf.convert_to_tensor(p, dtype=dtype)
-> 1573 return _igammainv_custom_gradient(a, p)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/math/special.py:1546, in _igammainv_custom_gradient()
   1541 @tfp_custom_gradient.custom_gradient(
   1542     vjp_fwd=_igammainv_fwd,
   1543     vjp_bwd=_igammainv_bwd,
   1544     jvp_fn=_igammainv_jvp)
   1545 def _igammainv_custom_gradient(a, p):
-> 1546   return _shared_igammainv_computation(a, p, is_igammainv=True)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/math/special.py:1470, in _shared_igammainv_computation()
   1466 factorial = tf.math.exp(a * tf.math.log(x) - x - tf.math.lgamma(a))
   1468 f_over_der = tf.where(
   1469     ((p <= 0.9) & is_igammainv) | ((q > 0.9) & (not is_igammainv)),
-> 1470     (tf.math.igamma(a, x) - p) * x / factorial,
   1471     -(tf.math.igammac(a, x) - q) * x / factorial)
   1472 second_der_over_der = -1. + (a - 1.) / x

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_math.py:656, in <lambda>()
    650 greater_equal = utils.copy_docstring(
    651     'tf.math.greater_equal',
    652     lambda x, y, name=None: np.greater_equal(x, y))
    654 igamma = utils.copy_docstring(
    655     'tf.math.igamma',
--> 656     lambda a, x, name=None: scipy_special.gammainc(a, x))
    658 igammac = utils.copy_docstring(
    659     'tf.math.igammac',
    660     lambda a, x, name=None: scipy_special.gammaincc(a, x))

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/scipy/special.py:320, in gammainc()
    319 a, x = promote_args_inexact("gammainc", a, x)
--> 320 return lax.igamma(a, x)

JaxStackTraceBeforeTransformation: KeyboardInterrupt

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

KeyboardInterrupt                         Traceback (most recent call last)
Cell In[41], line 2
      1 results_are = pd.DataFrame(
----> 2     [
      3         asymptotic_variance(n, k)
      4         for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are))
      5     ]
      6 )
      8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False)
      9 results_are

Cell In[41], line 3, in <listcomp>(.0)
      1 results_are = pd.DataFrame(
      2     [
----> 3         asymptotic_variance(n, k)
      4         for n, k in tqdm(zip(ns_are, keys_are), total=len(ns_are))
      5     ]
      6 )
      8 results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False)
      9 results_are

Cell In[27], line 37, in asymptotic_variance(n, key)
     34 sks_meis = sks[:M]
     35 sks_cem = sks[M:]
---> 37 logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M=len(sks_cem))
     38 logdet_meis = asymptotic_det_meis(
     39     Y, pgssm, prop_la, N_iter, N_samples, key, M=len(sks_meis)
     40 )
     42 result = pd.Series(
     43     {
     44         "n": n,
   (...)
     50     }
     51 )

Cell In[27], line 16, in asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M)
     14 def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int):
     15     key, *subkeys = jrn.split(key, 1 + M)
---> 16     proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys]
     17     modes = jnp.array([proposal.mean[:, 0] for proposal in proposals])
     18     cov = jnp.cov(modes, rowvar=False) * N_samples

Cell In[27], line 16, in <listcomp>(.0)
     14 def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int):
     15     key, *subkeys = jrn.split(key, 1 + M)
---> 16     proposals = [CEM(pgssm, Y, N_samples, sk_cem, N_iter)[0] for sk_cem in subkeys]
     17     modes = jnp.array([proposal.mean[:, 0] for proposal in proposals])
     18     cov = jnp.cov(modes, rowvar=False) * N_samples

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/isssm/ce_method.py:222, in cross_entropy_method(model, y, N, key, n_iter)
    218     new_proposal = proposal_from_moments(mean, consecutive_covs)
    220     return new_proposal, log_w
--> 222 final_proposal, log_w = fori_loop(
    223     0, n_iter, _iteration, (initial, jnp.empty(4 * N))
    224 )
    226 return final_proposal, log_w

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:2386, in fori_loop(lower, upper, body_fun, init_val, unroll)
   2384   scan_body = _fori_scan_body_fun(body_fun)
   2385   api_util.save_wrapped_fun_sourceinfo(scan_body, body_fun)
-> 2386   (_, result), _ = scan(
   2387       scan_body,
   2388       (lower_, init_val),
   2389       None,
   2390       length=length,
   2391       unroll=unroll,
   2392   )
   2393   return result
   2394 if unroll is not None and unroll is not False and unroll != 1:

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:355, in scan(f, init, xs, length, reverse, unroll, _split_transpose)
    352   consts = [*new_consts, *consts]
    353   num_carry -= len(new_consts)
--> 355 out = scan_p.bind(*consts, *in_flat,
    356                   reverse=reverse, length=length, jaxpr=jaxpr,
    357                   num_consts=len(consts), num_carry=num_carry,
    358                   linear=(False,) * (len(consts) + len(in_flat)),
    359                   unroll=unroll, _split_transpose=_split_transpose)
    361 if any(move_to_const):
    362   out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params)
    529 def bind(self, *args, **params):
    530   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 531   return self._true_bind(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params)
    549 trace_ctx.set_trace(eval_trace)
    550 try:
--> 551   return self.bind_with_trace(prev_trace, args, params)
    552 finally:
    553   trace_ctx.set_trace(prev_trace)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params)
    555 def bind_with_trace(self, trace, args, params):
--> 556   return trace.process_primitive(self, args, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:1060, in EvalTrace.process_primitive(self, primitive, args, params)
   1058 args = map(full_lower, args)
   1059 check_eval_args(args)
-> 1060 return primitive.impl(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/dispatch.py:88, in apply_primitive(prim, *args, **params)
     86 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     87 try:
---> 88   outs = fun(*args)
     89 finally:
     90   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:334, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
    329 if config.no_tracing.value:
    330   raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
    331                      "`jit`, but 'no_tracing' is set")
    333 (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, box_data,
--> 334  executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
    336 maybe_fastpath_data = _get_fastpath_data(
    337     executable, out_tree, args_flat, out_flat, attrs_tracked, box_data,
    338     jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler)
    340 return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:195, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    193   args_flat = map(core.full_lower, args_flat)
    194   core.check_eval_args(args_flat)
--> 195   out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
    196 else:
    197   out_flat = pjit_p.bind(*args_flat, **p.params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1853, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)
   1850 compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
   1851 # Passing mutable PGLE profile here since it should be extracted by JAXPR to
   1852 # initialize the fdo_profile compile option.
-> 1853 compiled = _resolve_and_lower(
   1854     args, jaxpr=jaxpr, in_shardings=in_shardings,
   1855     out_shardings=out_shardings, in_layouts=in_layouts,
   1856     out_layouts=out_layouts, donated_invars=donated_invars,
   1857     ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
   1858     inline=inline, lowering_platforms=None,
   1859     lowering_parameters=mlir.LoweringParameters(),
   1860     pgle_profiler=pgle_profiler,
   1861     compiler_options_kvs=compiler_options_kvs,
   1862 ).compile()
   1864 # This check is expensive so only do it if enable_checks is on.
   1865 if compiled._auto_spmd_lowering and config.enable_checks.value:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1820, in _resolve_and_lower(args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler, compiler_options_kvs)
   1817 in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
   1818                                  jaxpr.in_avals)
   1819 out_layouts = _resolve_out_layouts(out_layouts, out_shardings, jaxpr.out_avals)
-> 1820 return _pjit_lower(
   1821     jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
   1822     donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs,
   1823     lowering_platforms=lowering_platforms,
   1824     lowering_parameters=lowering_parameters,
   1825     pgle_profiler=pgle_profiler)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1953, in _pjit_lower(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler)
   1936 def _pjit_lower(
   1937     jaxpr: core.ClosedJaxpr,
   1938     in_shardings,
   (...)
   1950     lowering_parameters: mlir.LoweringParameters,
   1951     pgle_profiler: profiler.PGLEProfiler | None):
   1952   util.test_event("pjit_lower")
-> 1953   return pxla.lower_sharding_computation(
   1954       jaxpr, 'jit', name, in_shardings, out_shardings,
   1955       in_layouts, out_layouts, tuple(donated_invars),
   1956       keep_unused=keep_unused, context_mesh=ctx_mesh,
   1957       compiler_options_kvs=compiler_options_kvs,
   1958       lowering_platforms=lowering_platforms,
   1959       lowering_parameters=lowering_parameters,
   1960       pgle_profiler=pgle_profiler)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/profiler.py:354, in annotate_function.<locals>.wrapper(*args, **kwargs)
    351 @wraps(func)
    352 def wrapper(*args, **kwargs):
    353   with TraceAnnotation(name, **decorator_kwargs):
--> 354     return func(*args, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:2378, in lower_sharding_computation(closed_jaxpr, api_name, fun_name, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, keep_unused, context_mesh, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler)
   2372 semantic_in_shardings = SemanticallyEqualShardings(
   2373     in_shardings, global_in_avals)
   2374 semantic_out_shardings = SemanticallyEqualShardings(
   2375     out_shardings, global_out_avals)
   2377 (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
-> 2378  nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
   2379      closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
   2380      semantic_out_shardings, in_layouts, out_layouts, num_devices,
   2381      tuple(da_object) if prim_requires_devices else None,  # type: ignore[arg-type]
   2382      donated_invars, name_stack, all_default_mem_kind, inout_aliases,
   2383      propagated_out_mem_kinds, platforms,
   2384      lowering_parameters=lowering_parameters,
   2385      abstract_mesh=abstract_mesh)
   2387 # backend and device_assignment is passed through to MeshExecutable because
   2388 # if keep_unused=False and all in_shardings are pruned, then there is no way
   2389 # to get the device_assignment and backend. So pass it to MeshExecutable
   2390 # because we calculate the device_assignment and backend before in_shardings,
   2391 # etc are pruned.
   2392 return MeshComputation(
   2393     str(name_stack),
   2394     module,
   (...)
   2421     intermediate_shardings=unique_intermediate_shardings,
   2422     context_mesh=context_mesh)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1968, in _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, in_layouts, out_layouts, num_devices, device_assignment, donated_invars, name_stack, all_default_mem_kind, inout_aliases, propagated_out_mem_kinds, platforms, lowering_parameters, abstract_mesh)
   1964 ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
   1965 with dispatch.log_elapsed_time(
   1966       "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
   1967       fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
-> 1968   lowering_result = mlir.lower_jaxpr_to_module(
   1969       module_name,
   1970       closed_jaxpr,
   1971       ordered_effects=ordered_effects,
   1972       backend=backend,
   1973       platforms=platforms,
   1974       axis_context=axis_ctx,
   1975       name_stack=name_stack,
   1976       donated_args=donated_invars,
   1977       replicated_args=replicated_args,
   1978       arg_shardings=in_mlir_shardings,
   1979       result_shardings=out_mlir_shardings,
   1980       in_layouts=in_layouts,
   1981       out_layouts=out_layouts,
   1982       arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
   1983       result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
   1984       num_replicas=nreps,
   1985       num_partitions=num_partitions,
   1986       all_default_mem_kind=all_default_mem_kind,
   1987       input_output_aliases=inout_aliases,
   1988       propagated_out_mem_kinds=propagated_out_mem_kinds,
   1989       lowering_parameters=lowering_parameters)
   1990 tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   1991 unordered_effects = list(
   1992     effects.ordered_effects.filter_not_in(closed_jaxpr.effects))

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1271, in lower_jaxpr_to_module(***failed resolving arguments***)
   1269   attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
   1270   attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
-> 1271   lower_jaxpr_to_fun(
   1272       ctx, "main", jaxpr, ordered_effects,
   1273       name_stack=name_stack,
   1274       public=True,
   1275       replicated_args=replicated_args,
   1276       arg_shardings=arg_shardings,
   1277       result_shardings=result_shardings,
   1278       input_output_aliases=input_output_aliases,
   1279       xla_donated_args=xla_donated_args,
   1280       arg_names=arg_names,
   1281       result_names=result_names,
   1282       arg_memory_kinds=arg_memory_kinds,
   1283       result_memory_kinds=result_memory_kinds,
   1284       arg_layouts=in_layouts,
   1285       result_layouts=out_layouts,
   1286       propagated_out_mem_kinds=propagated_out_mem_kinds)
   1288 try:
   1289   if not ctx.module.operation.verify():

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1763, in lower_jaxpr_to_fun(ctx, name, jaxpr, effects, name_stack, public, replicated_args, arg_shardings, result_shardings, use_sharding_annotations, input_output_aliases, xla_donated_args, api_name, arg_names, result_names, arg_memory_kinds, result_memory_kinds, arg_layouts, result_layouts, propagated_out_mem_kinds)
   1761   callee_name_stack = name_stack
   1762 consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
-> 1763 out_vals, tokens_out = jaxpr_subcomp(
   1764     ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
   1765     consts, *args, dim_var_values=dim_var_values)
   1766 outs: list[IrValues] = []
   1767 for eff in effects:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   2037   rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
   2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
-> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive),
   2041                          platform_rules, default_rule,
   2042                          eqn.effects,
   2043                          *in_nodes, **eqn.params)
   2045 if effects:
   2046   # If there were ordered effects in the primitive, there should be output
   2047   # tokens we need for subsequent ordered effects.
   2048   tokens_out = rule_ctx.tokens_out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
   2160 # If there is a single rule left just apply the rule, without conditionals.
   2161 if len(kept_rules) == 1:
-> 2162   output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
   2163   foreach(
   2164       lambda o: wrap_compute_type_in_place(ctx, o.owner),
   2165       filter(_is_not_block_argument, flatten_ir_values(output)),
   2166   )
   2167   foreach(
   2168       lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
   2169       flatten_ir_values(output),
   2170   )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2278, in lower_fun.<locals>.f_lowered(ctx, *args, **params)
   2276 else:
   2277   sub_context = ctx.module_context
-> 2278 out, tokens = jaxpr_subcomp(
   2279     sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
   2280     _ir_consts(consts), *args,
   2281     dim_var_values=ctx.dim_var_values)
   2282 ctx.set_tokens_out(tokens)
   2283 return out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   2037   rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
   2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
-> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive),
   2041                          platform_rules, default_rule,
   2042                          eqn.effects,
   2043                          *in_nodes, **eqn.params)
   2045 if effects:
   2046   # If there were ordered effects in the primitive, there should be output
   2047   # tokens we need for subsequent ordered effects.
   2048   tokens_out = rule_ctx.tokens_out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
   2160 # If there is a single rule left just apply the rule, without conditionals.
   2161 if len(kept_rules) == 1:
-> 2162   output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
   2163   foreach(
   2164       lambda o: wrap_compute_type_in_place(ctx, o.owner),
   2165       filter(_is_not_block_argument, flatten_ir_values(output)),
   2166   )
   2167   foreach(
   2168       lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
   2169       flatten_ir_values(output),
   2170   )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:2101, in _while_lowering(ctx, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts, *args)
   2098 body_name_stack = name_stack.extend('body')
   2099 body_consts = [mlir.ir_constant(xla.canonicalize_dtype(x))
   2100                for x in body_jaxpr.consts]
-> 2101 new_z, tokens_out = mlir.jaxpr_subcomp(
   2102     ctx.module_context, body_jaxpr.jaxpr, body_name_stack,
   2103     tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values)
   2104 out_tokens = [tokens_out.get(eff) for eff in body_effects]
   2105 if batched:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   2037   rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
   2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
-> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive),
   2041                          platform_rules, default_rule,
   2042                          eqn.effects,
   2043                          *in_nodes, **eqn.params)
   2045 if effects:
   2046   # If there were ordered effects in the primitive, there should be output
   2047   # tokens we need for subsequent ordered effects.
   2048   tokens_out = rule_ctx.tokens_out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
   2160 # If there is a single rule left just apply the rule, without conditionals.
   2161 if len(kept_rules) == 1:
-> 2162   output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
   2163   foreach(
   2164       lambda o: wrap_compute_type_in_place(ctx, o.owner),
   2165       filter(_is_not_block_argument, flatten_ir_values(output)),
   2166   )
   2167   foreach(
   2168       lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
   2169       flatten_ir_values(output),
   2170   )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2383, in core_call_lowering(ctx, name, backend, call_jaxpr, *args)
   2381 def core_call_lowering(ctx: LoweringRuleContext,
   2382                        *args, name, backend=None, call_jaxpr):
-> 2383   out_nodes, tokens = call_lowering(
   2384       name, ctx.name_stack, call_jaxpr, backend, ctx.module_context,
   2385       ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
   2386       dim_var_values=ctx.dim_var_values)
   2387   ctx.set_tokens_out(tokens)
   2388   return out_nodes

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2366, in call_lowering(***failed resolving arguments***)
   2360 def call_lowering(fn_name, name_stack, call_jaxpr, backend,
   2361                   ctx: ModuleContext, avals_in,
   2362                   avals_out, tokens_in, *args,
   2363                   dim_var_values: Sequence[ir.Value],
   2364                   arg_names=None, result_names=None):
   2365   del avals_in
-> 2366   func_op, output_types, effects = lower_called_computation(
   2367       fn_name, name_stack, call_jaxpr, ctx, avals_out, tokens_in,
   2368       backend=backend, arg_names=arg_names, result_names=result_names)
   2369   symbol_name = func_op.name.value
   2370   flat_output_types = flatten_ir_types(output_types)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2348, in lower_called_computation(fn_name, name_stack, call_jaxpr, ctx, avals_out, tokens_in, backend, arg_names, result_names)
   2346 output_types = map(aval_to_ir_type, avals_out)
   2347 output_types = [token_type()] * len(effects) + output_types
-> 2348 func_op = _lower_jaxpr_to_fun_cached(
   2349     ctx,
   2350     fn_name,
   2351     call_jaxpr,
   2352     effects,
   2353     name_stack,
   2354     arg_names=arg_names,
   2355     result_names=result_names,
   2356 )
   2357 return func_op, output_types, effects

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2297, in _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack, arg_names, result_names)
   2295 except KeyError:
   2296   num_callbacks = len(ctx.host_callbacks)
-> 2297   func_op = lower_jaxpr_to_fun(
   2298       ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
   2299       result_names=result_names)
   2301   # If this Jaxpr includes callbacks, we can't cache the lowering because
   2302   # on TPU every callback must have a globally unique channel, but the
   2303   # channel gets assigned during lowering.
   2304   has_callbacks = len(ctx.host_callbacks) > num_callbacks

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1763, in lower_jaxpr_to_fun(ctx, name, jaxpr, effects, name_stack, public, replicated_args, arg_shardings, result_shardings, use_sharding_annotations, input_output_aliases, xla_donated_args, api_name, arg_names, result_names, arg_memory_kinds, result_memory_kinds, arg_layouts, result_layouts, propagated_out_mem_kinds)
   1761   callee_name_stack = name_stack
   1762 consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
-> 1763 out_vals, tokens_out = jaxpr_subcomp(
   1764     ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
   1765     consts, *args, dim_var_values=dim_var_values)
   1766 outs: list[IrValues] = []
   1767 for eff in effects:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   2037   rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
   2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
-> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive),
   2041                          platform_rules, default_rule,
   2042                          eqn.effects,
   2043                          *in_nodes, **eqn.params)
   2045 if effects:
   2046   # If there were ordered effects in the primitive, there should be output
   2047   # tokens we need for subsequent ordered effects.
   2048   tokens_out = rule_ctx.tokens_out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
   2160 # If there is a single rule left just apply the rule, without conditionals.
   2161 if len(kept_rules) == 1:
-> 2162   output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
   2163   foreach(
   2164       lambda o: wrap_compute_type_in_place(ctx, o.owner),
   2165       filter(_is_not_block_argument, flatten_ir_values(output)),
   2166   )
   2167   foreach(
   2168       lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
   2169       flatten_ir_values(output),
   2170   )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/custom_derivatives.py:430, in _custom_jvp_vjp_call_lowering(ctx, call_jaxpr, *args, **_)
    428 def _custom_jvp_vjp_call_lowering(ctx, *args, call_jaxpr, **_):
    429   consts = mlir._ir_consts(call_jaxpr.consts)
--> 430   out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
    431                                    ctx.name_stack, ctx.tokens_in, consts,
    432                                    *args, dim_var_values=ctx.dim_var_values)
    433   ctx.set_tokens_out(tokens)
    434   return out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2040, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   2037   rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
   2039 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
-> 2040 ans = lower_per_platform(rule_ctx, str(eqn.primitive),
   2041                          platform_rules, default_rule,
   2042                          eqn.effects,
   2043                          *in_nodes, **eqn.params)
   2045 if effects:
   2046   # If there were ordered effects in the primitive, there should be output
   2047   # tokens we need for subsequent ordered effects.
   2048   tokens_out = rule_ctx.tokens_out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2162, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
   2160 # If there is a single rule left just apply the rule, without conditionals.
   2161 if len(kept_rules) == 1:
-> 2162   output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
   2163   foreach(
   2164       lambda o: wrap_compute_type_in_place(ctx, o.owner),
   2165       filter(_is_not_block_argument, flatten_ir_values(output)),
   2166   )
   2167   foreach(
   2168       lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
   2169       flatten_ir_values(output),
   2170   )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2278, in lower_fun.<locals>.f_lowered(ctx, *args, **params)
   2276 else:
   2277   sub_context = ctx.module_context
-> 2278 out, tokens = jaxpr_subcomp(
   2279     sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
   2280     _ir_consts(consts), *args,
   2281     dim_var_values=ctx.dim_var_values)
   2282 ctx.set_tokens_out(tokens)
   2283 return out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:2006, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   2003 in_nodes = map(read, eqn.invars)
   2004 source_info = eqn.source_info.replace(
   2005     name_stack=name_stack + eqn.source_info.name_stack)
-> 2006 loc = _source_info_to_location(ctx, eqn.primitive, source_info)
   2007 with (source_info_util.user_context(eqn.source_info.traceback), loc,
   2008       eqn.ctx.manager):
   2009   override_rule = get_override_lowering_rule(eqn.primitive)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:516, in _source_info_to_location(ctx, primitive, source_info)
    512   else:
    513     loc = ir.Location.file(get_canonical_source_file(frame.file_name,
    514                                                      ctx.traceback_caches),
    515                            frame.start_line, frame.start_column)
--> 516 loc = ir.Location.name(eqn_str, childLoc=loc)
    517 # TODO(phawkins): also include primitive.name as the operator type.
    518 return loc

KeyboardInterrupt: